Skip to content

Commit

Permalink
feat: Use SUNDIALS v7 on Windows
Browse files Browse the repository at this point in the history
  • Loading branch information
vasylskorych committed Sep 23, 2024
1 parent 5d2b90d commit b21faa7
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 51 deletions.
64 changes: 41 additions & 23 deletions EquationSolvers/DAESolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ bool CDAESolver::SetModel(CDAEModel* _model)
return true;
}

bool CDAESolver::Calculate(realtype _time)
bool CDAESolver::Calculate(sun_real _time)
{
if (_time == 0.0)
{
Expand All @@ -70,7 +70,7 @@ bool CDAESolver::Calculate(realtype _time)
return true;
}

bool CDAESolver::Calculate(realtype _timeBeg, realtype _timeEnd)
bool CDAESolver::Calculate(sun_real _timeBeg, sun_real _timeEnd)
{
int res;

Expand All @@ -94,7 +94,7 @@ bool CDAESolver::Calculate(realtype _timeBeg, realtype _timeEnd)
return WriteError("IDA", "IDASetMaxStep", "Cannot set maximum absolute step size");

/* Get current integration step.*/
realtype currStep;
sun_real currStep;
res = IDAGetCurrentStep(m_solverMem.idamem, &currStep);
if (res != IDA_SUCCESS)
return WriteError("IDA", "IDAGetCurrentStep", "Cannot read current time step.");
Expand Down Expand Up @@ -140,7 +140,7 @@ bool CDAESolver::CalculateInitialConditions()
return true;
}

bool CDAESolver::IntegrateUntil(realtype _time)
bool CDAESolver::IntegrateUntil(sun_real _time)
{
/* set integration limit */
int res = IDASetStopTime(m_solverMem.idamem, _time);
Expand All @@ -166,12 +166,12 @@ void CDAESolver::SaveState()
const size_t len = m_model->GetVariablesNumber();
const auto* src = static_cast<IDAMem>(m_solverMem.idamem);

std::memcpy(m_solverMem_store.vars.data(), N_VGetArrayPointer(m_solverMem.vars), sizeof(realtype) * len);
std::memcpy(m_solverMem_store.ders.data(), N_VGetArrayPointer(m_solverMem.ders), sizeof(realtype) * len);
std::memcpy(m_solverMem_store.vars.data(), N_VGetArrayPointer(m_solverMem.vars), sizeof(sun_real) * len);
std::memcpy(m_solverMem_store.ders.data(), N_VGetArrayPointer(m_solverMem.ders), sizeof(sun_real) * len);

for (size_t i = 0; i < MXORDP1; ++i)
std::memcpy(m_solverMem_store.ida_phi[i].data(), N_VGetArrayPointer(src->ida_phi[i]), sizeof(realtype) * len);
std::memcpy(m_solverMem_store.ida_psi.data(), src->ida_psi, sizeof(realtype) * MXORDP1);
std::memcpy(m_solverMem_store.ida_phi[i].data(), N_VGetArrayPointer(src->ida_phi[i]), sizeof(sun_real) * len);
std::memcpy(m_solverMem_store.ida_psi.data(), src->ida_psi, sizeof(sun_real) * MXORDP1);
m_solverMem_store.ida_kused = src->ida_kused;
m_solverMem_store.ida_ns = src->ida_ns;
m_solverMem_store.ida_hh = src->ida_hh;
Expand All @@ -187,12 +187,12 @@ void CDAESolver::LoadState() const
const size_t len = m_model->GetVariablesNumber();
auto* dst = static_cast<IDAMem>(m_solverMem.idamem);

std::memcpy(N_VGetArrayPointer(m_solverMem.vars), m_solverMem_store.vars.data(), sizeof(realtype) * len);
std::memcpy(N_VGetArrayPointer(m_solverMem.ders), m_solverMem_store.ders.data(), sizeof(realtype) * len);
std::memcpy(N_VGetArrayPointer(m_solverMem.vars), m_solverMem_store.vars.data(), sizeof(sun_real) * len);
std::memcpy(N_VGetArrayPointer(m_solverMem.ders), m_solverMem_store.ders.data(), sizeof(sun_real) * len);

for (size_t i = 0; i < MXORDP1; ++i)
std::memcpy(N_VGetArrayPointer(dst->ida_phi[i]), m_solverMem_store.ida_phi[i].data(), sizeof(realtype) * len);
std::memcpy(dst->ida_psi, m_solverMem_store.ida_psi.data(), sizeof(realtype) * MXORDP1);
std::memcpy(N_VGetArrayPointer(dst->ida_phi[i]), m_solverMem_store.ida_phi[i].data(), sizeof(sun_real) * len);
std::memcpy(dst->ida_psi, m_solverMem_store.ida_psi.data(), sizeof(sun_real) * MXORDP1);
dst->ida_kused = m_solverMem_store.ida_kused;
dst->ida_ns = m_solverMem_store.ida_ns;
dst->ida_hh = m_solverMem_store.ida_hh;
Expand Down Expand Up @@ -231,8 +231,12 @@ bool CDAESolver::InitSolverMemory(SSolverMemory& _mem)
int res; // return value

// create context
#if SUNDIALS_VERSION_MAJOR >= 6
#if SUNDIALS_VERSION_MAJOR == 6
res = SUNContext_Create(nullptr, &_mem.sunctx);
#else
res = SUNContext_Create(SUN_COMM_NULL, &_mem.sunctx);
#endif
#if SUNDIALS_VERSION_MAJOR >= 6
if (res != IDA_SUCCESS)
return WriteError("IDA", "SUNContext_Create", "Cannot create SUNDIALS context.");
#endif
Expand All @@ -249,11 +253,11 @@ bool CDAESolver::InitSolverMemory(SSolverMemory& _mem)
return WriteError("IDA", "N_VNew_Serial", "Cannot create vectors.");

// initialize vectors
std::memcpy(N_VGetArrayPointer(_mem.vars) , m_model->GetVarInitValues() .data(), sizeof(realtype) * len);
std::memcpy(N_VGetArrayPointer(_mem.ders) , m_model->GetDerInitValues() .data(), sizeof(realtype) * len);
std::memcpy(N_VGetArrayPointer(_mem.atols) , m_model->GetATols() .data(), sizeof(realtype) * len);
std::memcpy(N_VGetArrayPointer(_mem.types) , m_model->GetVarTypes() .data(), sizeof(realtype) * len);
std::memcpy(N_VGetArrayPointer(_mem.constr), m_model->GetConstraintValues().data(), sizeof(realtype) * len);
std::memcpy(N_VGetArrayPointer(_mem.vars) , m_model->GetVarInitValues() .data(), sizeof(sun_real) * len);
std::memcpy(N_VGetArrayPointer(_mem.ders) , m_model->GetDerInitValues() .data(), sizeof(sun_real) * len);
std::memcpy(N_VGetArrayPointer(_mem.atols) , m_model->GetATols() .data(), sizeof(sun_real) * len);
std::memcpy(N_VGetArrayPointer(_mem.types) , m_model->GetVarTypes() .data(), sizeof(sun_real) * len);
std::memcpy(N_VGetArrayPointer(_mem.constr), m_model->GetConstraintValues().data(), sizeof(sun_real) * len);

// create matrix object
_mem.sunmatr = SUNDenseMatrix(len, len MAYBE_COMMA_CONTEXT(m_solverMem));
Expand Down Expand Up @@ -287,9 +291,15 @@ bool CDAESolver::InitSolverMemory(SSolverMemory& _mem)

// set optional inputs
// set error handler function
#if SUNDIALS_VERSION_MAJOR < 7
res = IDASetErrHandlerFn(_mem.idamem, &CDAESolver::ErrorHandler, &m_errorMessage);
if (res != IDA_SUCCESS)
return WriteError("IDA", "IDASetErrHandlerFn", "Cannot setup error handler function.");
#else
res = SUNContext_PushErrHandler(_mem.sunctx, &CDAESolver::ErrorHandler, &m_errorMessage);
if (res != IDA_SUCCESS)
return WriteError("IDA", "SUNContext_PushErrHandler", "Cannot setup error handler function.");
#endif
// set model as user data
res = IDASetUserData(_mem.idamem, m_model);
if (res != IDA_SUCCESS)
Expand Down Expand Up @@ -330,7 +340,7 @@ void CDAESolver::InitStoreMemory(SStoreMemory& _mem) const
const auto len = static_cast<sunindextype>(m_model->GetVariablesNumber());
_mem.vars.resize(len);
_mem.ders.resize(len);
_mem.ida_phi.resize(MXORDP1, std::vector<realtype>(len));
_mem.ida_phi.resize(MXORDP1, std::vector<sun_real>(len));
_mem.ida_psi.resize(MXORDP1);
}

Expand All @@ -342,11 +352,11 @@ void CDAESolver::Clear()
ClearSolverMemory(m_solverMem);
}

int CDAESolver::ResidualFunction(realtype _time, N_Vector _vals, N_Vector _ders, N_Vector _ress, void* _model)
int CDAESolver::ResidualFunction(sun_real _time, N_Vector _vals, N_Vector _ders, N_Vector _ress, void* _model)
{
realtype* vals = N_VGetArrayPointer(_vals);
realtype* ders = N_VGetArrayPointer(_ders);
realtype* ress = N_VGetArrayPointer(_ress);
sun_real* vals = N_VGetArrayPointer(_vals);
sun_real* ders = N_VGetArrayPointer(_ders);
sun_real* ress = N_VGetArrayPointer(_ress);
const bool res = static_cast<CDAEModel*>(_model)->GetResiduals(_time, vals, ders, ress);
return res ? 0 : -1;
}
Expand All @@ -359,6 +369,14 @@ void CDAESolver::ErrorHandler(int _errorCode, const char* _module, const char* _
AppendMessage(_module, _function, _message, out);
}

void CDAESolver::ErrorHandler(int _line, const char* _function, const char* _file, const char* _message, SUNErrCode _errCode, void* _outString, [[maybe_unused]] SUNContext _sunctx)
{
if (!_outString) return;
if (!_errCode) return;
std::string& out = *static_cast<std::string*>(_outString);
AppendMessage(_file, _function, _message, out);
}

std::string CDAESolver::BuildErrorMessage(const std::string& _module, const std::string& _function, const std::string& _message)
{
return "[" + _module + " ERROR] in " + _function + ": " + _message;
Expand Down
44 changes: 31 additions & 13 deletions EquationSolvers/DAESolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
/** Solver of differential algebraic equations. Uses IDA solver from SUNDIALS package*/
class CDAESolver
{
#if SUNDIALS_VERSION_MAJOR <= 6
using sun_real = realtype;
#else
using sun_real = sunrealtype;
#endif

/** Memory needed for solver. */
struct SSolverMemory
{
Expand All @@ -29,15 +35,15 @@ class CDAESolver
/** Data field from IDA memory needed to be temporary stored. */
struct SStoreMemory
{
std::vector<realtype> vars;
std::vector<realtype> ders;
std::vector<std::vector<realtype>> ida_phi;
std::vector<realtype> ida_psi;
std::vector<sun_real> vars;
std::vector<sun_real> ders;
std::vector<std::vector<sun_real>> ida_phi;
std::vector<sun_real> ida_psi;
int ida_kused;
int ida_ns;
realtype ida_hh;
realtype ida_tn;
realtype ida_cj;
sun_real ida_hh;
sun_real ida_tn;
sun_real ida_cj;
long int ida_nst;
};

Expand All @@ -46,8 +52,8 @@ class CDAESolver
SSolverMemory m_solverMem{}; ///< Solver-specific memory.
SStoreMemory m_solverMem_store{}; ///< Solver-specific memory for temporary storing.

realtype m_timeLast{}; ///< Last calculated time point.
realtype m_maxStep{}; ///< Maximum iteration time step.
sun_real m_timeLast{}; ///< Last calculated time point.
sun_real m_maxStep{}; ///< Maximum iteration time step.
size_t m_maxNumSteps{ 500 }; ///< Maximum number of allowed solver iterations.

std::string m_errorMessage; ///< Text description of the occurred errors.
Expand All @@ -66,20 +72,20 @@ class CDAESolver
/** Solve problem on a given time point.
* \param _time Time point.
* \retval true No errors occurred. */
bool Calculate(realtype _time);
bool Calculate(sun_real _time);
/** Solve problem on a given time interval.
* \param _timeBeg Start of the time interval.
* \param _timeEnd End of the time interval.
* \retval true No errors occurred. */
bool Calculate(realtype _timeBeg, realtype _timeEnd);
bool Calculate(sun_real _timeBeg, sun_real _timeEnd);

/** Calculates and applies corrected initial conditions.
* \retval true No errors occurred. */
bool CalculateInitialConditions();
/** Integrates the problem until the given time point.
* \param _time Final time of integration.
* \retval true No errors occurred. */
bool IntegrateUntil(realtype _time);
bool IntegrateUntil(sun_real _time);

/** Save current state of solver.
* Should be called during saving of unit. */
Expand Down Expand Up @@ -128,7 +134,7 @@ class CDAESolver
* \param _ress Output residual vector F(t, y, y').
* \param _model Pointer to a DAE model.
* \return Error code. */
static int ResidualFunction(realtype _time, N_Vector _vals, N_Vector _ders, N_Vector _ress, void *_model);
static int ResidualFunction(sun_real _time, N_Vector _vals, N_Vector _ders, N_Vector _ress, void *_model);

/** A callback function called by the solver to handle internal errors.
* \param _errorCode Error code
Expand All @@ -137,6 +143,18 @@ class CDAESolver
* \param _message The error message
* \param _outString Pointer to a string to put error message*/
static void ErrorHandler(int _errorCode, const char* _module, const char* _function, char* _message, void* _outString);
/**
* \brief A callback function called by the solver to handle internal errors.
* \details A version for SUNDIALS 7+.
* \param _line The line number at which the error occured.
* \param _function The function in which the error occured.
* \param _file The file in which the error occured.
* \param _message The error message.
* \param _errCode The error code for the error that occured.
* \param _outString Pointer to a string to put error message.
* \param _sunctx Pointer to a valid SUNContext object.
*/
static void ErrorHandler(int _line, const char* _function, const char* _file, const char* _message, SUNErrCode _errCode, void* _outString, SUNContext _sunctx);
/** Builds an error message from its parts.
* \param _module Name of the module reporting the error
* \param _function Name of the function in which the error occurred
Expand Down
51 changes: 42 additions & 9 deletions EquationSolvers/NLSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,18 @@ bool CNLSolver::SetModel(CNLModel* _pModel)
{
ClearMemory();

#if SUNDIALS_VERSION_MAJOR < 6
// create context
#if SUNDIALS_VERSION_MAJOR == 6
SUNErrCode res = SUNContext_Create(nullptr, &m_sunctx);
#else
SUNContext_Create(nullptr, &m_sunctx);
SUNErrCode res = SUNContext_Create(SUN_COMM_NULL, &m_sunctx);
#endif
#if SUNDIALS_VERSION_MAJOR >= 6
if (!res)
{
ErrorHandler(-1, "KIN", "SUNContext_Create", "Cannot create SUNDIALS context.", &m_sErrorDescription);
return false;
}
#endif

m_pModel = _pModel;
Expand All @@ -136,11 +145,22 @@ bool CNLSolver::SetModel(CNLModel* _pModel)
ErrorHandler(-1, "KIN", "KINCreate", "Cannot allocate memory for solver.", &m_sErrorDescription);
return false;
}
if (KINSetErrHandlerFn(m_pKINmem, &CNLSolver::ErrorHandler, &m_sErrorDescription) != KIN_SUCCESS)
// set error handler function
#if SUNDIALS_VERSION_MAJOR < 7
res = KINSetErrHandlerFn(m_pKINmem, &CNLSolver::ErrorHandler, &m_sErrorDescription));
if (res != KIN_SUCCESS)
{
ErrorHandler(-1, "KIN", "KINSetErrHandlerFn", "Cannot setup error handler function.", &m_sErrorDescription);
return false;
}
#else
res = SUNContext_PushErrHandler(m_sunctx, &CNLSolver::ErrorHandler, &m_sErrorDescription);
if (res != KIN_SUCCESS)
{
ErrorHandler(-1, "KIN", "SUNContext_PushErrHandler", "Cannot setup error handler function.", &m_sErrorDescription);
return false;
}
#endif

const sunindextype nVarsCnt = static_cast<sunindextype>(m_pModel->GetVariablesNumber());
m_pModel->SetStrategy(m_eStrategy);
Expand Down Expand Up @@ -205,8 +225,8 @@ bool CNLSolver::SetModel(CNLModel* _pModel)
else
{
KINSetMAA(m_pKINmem, m_nMAA);
KINSetDampingAA(m_pKINmem, static_cast<realtype>(m_dDampingAA));
KINSetDamping(m_pKINmem, static_cast<realtype>(m_dDamping));
KINSetDampingAA(m_pKINmem, static_cast<sun_real>(m_dDampingAA));
KINSetDamping(m_pKINmem, static_cast<sun_real>(m_dDamping));
}

// Initialize IDA memory
Expand Down Expand Up @@ -238,7 +258,7 @@ bool CNLSolver::SetModel(CNLModel* _pModel)
return true;
}

bool CNLSolver::Calculate(realtype _dTime)
bool CNLSolver::Calculate(sun_real _dTime)
{
const int ret = KINSol(m_pKINmem, m_vectorVars, (int)E2I(m_eStrategy), m_vectorUScales, m_vectorFScales);

Expand Down Expand Up @@ -267,8 +287,8 @@ std::string CNLSolver::GetError() const

int CNLSolver::ResidualFunction(N_Vector _value, N_Vector _func, void *_pModel)
{
realtype *pValue = NV_DATA_S(_value);
realtype *pFunc = NV_DATA_S(_func);
sun_real *pValue = NV_DATA_S(_value);
sun_real *pFunc = NV_DATA_S(_func);

const bool bRes = static_cast<CNLModel*>(_pModel)->GetFunctions(pValue, pFunc);

Expand Down Expand Up @@ -302,7 +322,7 @@ void CNLSolver::ClearMemory()
void CNLSolver::CopyNVector(N_Vector _dst, N_Vector _src)
{
if (_dst == nullptr || _src == nullptr) return;
std::memcpy(NV_DATA_S(_dst), NV_DATA_S(_src), sizeof(realtype)*static_cast<size_t>(NV_LENGTH_S(_src)));
std::memcpy(NV_DATA_S(_dst), NV_DATA_S(_src), sizeof(sun_real)*static_cast<size_t>(NV_LENGTH_S(_src)));
}

void CNLSolver::ErrorHandler(int _nErrorCode, const char *_pModule, const char *_pFunction, char *_pMsg, void *_sOutString)
Expand All @@ -321,6 +341,19 @@ void CNLSolver::ErrorHandler(int _nErrorCode, const char *_pModule, const char *
}
}

void CNLSolver::ErrorHandler(int _line, const char* _function, const char* _file, const char* _message, SUNErrCode _errCode, void* _outString, SUNContext _sunctx)
{
if (!_outString) return;
if (!_errCode) return;
std::string description = "[";
description += _file;
description += " ERROR] in ";
description += _function;
description += ": ";
description += _message;
*static_cast<std::string*>(_outString) = description;
}

#ifdef _MSC_VER
#else
#pragma GCC diagnostic pop
Expand Down
Loading

0 comments on commit b21faa7

Please sign in to comment.