Skip to content

Commit

Permalink
Allow enabling/disabling state reinitialization for individual states (
Browse files Browse the repository at this point in the history
…Closes #1345)
  • Loading branch information
dweindl committed Nov 30, 2020
1 parent 39f9b6e commit 3f3ca2e
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 65 deletions.
18 changes: 11 additions & 7 deletions include/amici/abstract_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <sunmatrix/sunmatrix_sparse.h>

#include <memory>
#include <set>

namespace amici {

Expand Down Expand Up @@ -230,9 +231,11 @@ class AbstractModel {
* @param t initial time
* @param p parameter vector
* @param k constant vector
* @return set of indices of states that have been reset
*/
virtual void fx0_fixedParameters(realtype *x0, const realtype t,
const realtype *p, const realtype *k);
virtual std::set<int> fx0_fixedParameters(
realtype *x0, const realtype t,
const realtype *p, const realtype *k);

/**
* @brief Model specific implementation of fsx0_fixedParameters
Expand All @@ -242,10 +245,11 @@ class AbstractModel {
* @param p parameter vector
* @param k constant vector
* @param ip sensitivity index
* @param resettedStateIdxs set of indices of states have been reset
*/
virtual void fsx0_fixedParameters(realtype *sx0, const realtype t,
const realtype *x0, const realtype *p,
const realtype *k, int ip);
virtual void fsx0_fixedParameters(
realtype *sx0, const realtype t, const realtype *x0, const realtype *p,
const realtype *k, int ip, const std::set<int>& resettedStateIdxs);

/**
* @brief Model specific implementation of fsx0
Expand Down Expand Up @@ -626,7 +630,7 @@ class AbstractModel {
virtual void fdJydy(realtype *dJydy, int iy, const realtype *p,
const realtype *k, const realtype *y,
const realtype *sigmay, const realtype *my);

/**
* @brief Model-specific implementation of fdJydy colptrs
* @param dJydy sparse matrix to which colptrs will be written
Expand Down Expand Up @@ -800,7 +804,7 @@ class AbstractModel {
* @param dwdx sparse matrix to which rowvals will be written
*/
virtual void fdwdx_rowvals(SUNMatrixWrapper &dwdx);

/**
* @brief Model specific implementation of fdwdw, no w chainrule (Py)
* @param dwdw partial derivative w wrt w
Expand Down
30 changes: 17 additions & 13 deletions include/amici/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ class Model : public AbstractModel {
using AbstractModel::fx0_fixedParameters;
using AbstractModel::fy;
using AbstractModel::fz;

/**
* @brief Initialize model properties.
* @param x Reference to state variables
Expand Down Expand Up @@ -597,7 +597,7 @@ class Model : public AbstractModel {
* @return Observable IDs
*/
virtual std::vector<std::string> getObservableIds() const;

/**
* @brief Checks whether the defined noise model is gaussian, i.e., the nllh is quadratic
* @return boolean flag
Expand Down Expand Up @@ -1204,8 +1204,9 @@ class Model : public AbstractModel {
* @brief Set only those initial states that are specified via
* fixed parameters.
* @param x Output buffer.
* @return set of indices of states that have been reset
*/
void fx0_fixedParameters(AmiVector &x);
std::set<int> fx0_fixedParameters(AmiVector &x);

/**
* @brief Compute/get initial value for initial state sensitivities.
Expand All @@ -1219,8 +1220,11 @@ class Model : public AbstractModel {
* from `amici::Model::fx0_fixedParameters`.
* @param sx Output buffer for state sensitivities
* @param x State variables
* @param resettedStateIdxs set of indices of states have been reset
*/
void fsx0_fixedParameters(AmiVectorArray &sx, const AmiVector &x);
void fsx0_fixedParameters(AmiVectorArray &sx,
const AmiVector &x,
const std::set<int>& resettedStateIdxs);

/**
* @brief Compute sensitivity of derivative initial states sensitivities
Expand Down Expand Up @@ -1298,13 +1302,13 @@ class Model : public AbstractModel {

/** Flag indicating Matlab- or Python-based model generation */
bool pythonGenerated;

/**
* @brief getter for dxdotdp (matlab generated)
* @return dxdotdp
*/
const AmiVectorArray &get_dxdotdp() const;

/**
* @brief getter for dxdotdp (python generated)
* @return dxdotdp
Expand Down Expand Up @@ -1642,7 +1646,7 @@ class Model : public AbstractModel {
* @param x Array with the states
*/
void fdwdx(realtype t, const realtype *x);

/**
* @brief Compute self derivative for recurring terms in xdot.
* @param t Timepoint
Expand Down Expand Up @@ -1752,13 +1756,13 @@ class Model : public AbstractModel {

/** Sparse dwdx temporary storage (dimension: `ndwdx`) */
mutable SUNMatrixWrapper dwdx_;

/** Sparse dwdp temporary storage (dimension: `ndwdp`) */
mutable SUNMatrixWrapper dwdp_;

/** Dense Mass matrix (dimension: `nx_solver` x `nx_solver`) */
mutable SUNMatrixWrapper M_;

/**
* Temporary storage of `dxdotdp_full` data across functions (Python only)
* (dimension: `nplist` x `nx_solver`, nnz: dynamic,
Expand All @@ -1780,7 +1784,7 @@ class Model : public AbstractModel {
* type `CSC_MAT`)
*/
mutable SUNMatrixWrapper dxdotdp_implicit;

/**
* Temporary storage of `dxdotdx_explicit` data across functions (Python only)
* (dimension: `nplist` x `nx_solver`, nnz: 'nxdotdotdx_explicit',
Expand Down Expand Up @@ -2005,10 +2009,10 @@ class Model : public AbstractModel {

/** Sparse dwdw temporary storage (dimension: `ndwdw`) */
mutable SUNMatrixWrapper dwdw_;

/** Sparse dwdx implicit temporary storage (dimension: `ndwdx`) */
mutable std::vector<SUNMatrixWrapper> dwdx_hierarchical_;

/** Recursion */
int w_recursion_depth_ {0};
};
Expand Down
42 changes: 32 additions & 10 deletions python/amici/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@
'signature':
'(realtype *x0_fixedParameters, const realtype t, '
'const realtype *p, const realtype *k)',
'ret_type': 'std::set<int>'
},
'sx0': {
'signature':
Expand All @@ -180,7 +181,7 @@
'signature':
'(realtype *sx0_fixedParameters, const realtype t, '
'const realtype *x0, const realtype *p, const realtype *k, '
'const int ip)',
'const int ip, const std::set<int> &resettedStateIdxs)',
},
'xdot': {
'signature':
Expand Down Expand Up @@ -2430,11 +2431,12 @@ def _write_function_file(self, function: str) -> None:
'#include "sundials/sundials_types.h"',
'',
'#include <array>',
'#include <set>',
]

# function signature
signature = self.functions[function]['signature']

ret_type = self.functions[function].get('ret_type', 'void')
lines.append('')

for sym in self.model.sym_names():
Expand All @@ -2452,7 +2454,7 @@ def _write_function_file(self, function: str) -> None:
'',
])

lines.append(f'void {function}_{self.model_name}{signature}{{')
lines.append(f'{ret_type} {function}_{self.model_name}{signature}{{')

# function body
body = self._get_function_body(function, equations)
Expand Down Expand Up @@ -2546,9 +2548,10 @@ def _write_function_index(self, function: str, indextype: str) -> None:
lines.append(' ' + ', '.join(map(str, values)))
lines.append("};")

ret_type = self.functions[function].get('ret_type', 'void')
lines.extend([
'',
f'void {function}_{indextype}_{self.model_name}{signature}{{',
f'{ret_type} {function}_{indextype}_{self.model_name}{signature}{{',
])

if len(values):
Expand Down Expand Up @@ -2630,18 +2633,34 @@ def _get_function_body(self,
):
if not formula.is_zero:
expressions.append(
f'if(resettedStateIdxs.find({index}) != '
'resettedStateIdxs.end()) '
f'{function}[{index}] = '
f'{_print_with_exception(formula)};')
cases[ipar] = expressions
lines.extend(get_switch_statement('ip', cases, 1))

elif function == 'x0_fixedParameters':
lines.append("realtype tmp;\n"
"std::set<int> resettedStateIdxs;")
for index, formula in zip(
self.model._x0_fixedParameters_idx,
equations
):
lines.append(f'{function}[{index}] = '
f'{_print_with_exception(formula)};')
lines.append(f'tmp = {_print_with_exception(formula)};\n'
'if(!std::isnan(tmp)) '
f'{function}[{index}] = tmp;\n'
f'resettedStateIdxs.emplace({index});')
lines.append("return resettedStateIdxs;")
elif function == 'x0':
lines.append("realtype tmp;")

lines.extend([
f' tmp = {_print_with_exception(math)};'
f' if(!std::isnan(tmp)) {function}[{index}] = tmp;'
for index, math in enumerate(equations)
if not (math == 0 or math == 0.0)
])

elif function in sensi_functions:
cases = {ipar: _get_sym_lines_array(equations[:, ipar], function,
Expand Down Expand Up @@ -3007,8 +3026,9 @@ def get_function_extern_declaration(fun: str, name: str) -> str:
c++ function definition string
"""
return \
f'extern void {fun}_{name}{functions[fun]["signature"]};'
signature = functions[fun]["signature"]
ret_type = functions[fun].get('ret_type', 'void')
return f'extern {ret_type} {fun}_{name}{signature};'


def get_sunindex_extern_declaration(fun: str, name: str,
Expand Down Expand Up @@ -3051,11 +3071,12 @@ def get_model_override_implementation(fun: str, name: str) -> str:
"""
return \
'virtual void f{fun}{signature} override {{\n' \
'virtual {ret_type} f{fun}{signature} override {{\n' \
'{ind8}{fun}_{name}{eval_signature};\n' \
'{ind4}}}\n'.format(
ind4=' '*4,
ind8=' '*8,
ret_type=functions[fun].get('ret_type', 'void'),
fun=fun,
name=name,
signature=functions[fun]["signature"],
Expand Down Expand Up @@ -3086,11 +3107,12 @@ def get_sunindex_override_implementation(fun: str, name: str,
index_arg_eval = ', index' if fun in multiobs_functions else ''

return \
'virtual void f{fun}_{indextype}{signature} override {{\n' \
'virtual {ret_type} f{fun}_{indextype}{signature} override {{\n' \
'{ind8}{fun}_{indextype}_{name}{eval_signature};\n' \
'{ind4}}}\n'.format(
ind4=' '*4,
ind8=' '*8,
ret_type=functions[fun].get('ret_type', 'void'),
fun=fun,
indextype=indextype,
name=name,
Expand Down
7 changes: 5 additions & 2 deletions src/abstract_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ AbstractModel::isFixedParameterStateReinitializationAllowed() const
return false;
}

void
std::set<int>
AbstractModel::fx0_fixedParameters(realtype* /*x0*/,
const realtype /*t*/,
const realtype* /*p*/,
const realtype* /*k*/)
{
// no-op default implementation
return std::set<int>();
}

void
Expand All @@ -46,7 +47,9 @@ AbstractModel::fsx0_fixedParameters(realtype* /*sx0*/,
const realtype* /*x0*/,
const realtype* /*p*/,
const realtype* /*k*/,
const int /*ip*/)
const int /*ip*/,
const std::set<int>& /*resettedStateIdxs*/
)
{
// no-op default implementation
}
Expand Down
26 changes: 15 additions & 11 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1149,19 +1149,21 @@ void Model::fx0(AmiVector &x) {
}
}

void Model::fx0_fixedParameters(AmiVector &x) {
std::set<int>Model::fx0_fixedParameters(AmiVector &x) {
if (!getReinitializeFixedParameterInitialStates())
return;
return std::set<int>();
/* we transform to the unreduced states x_rdata and then apply
x0_fixedparameters to (i) enable updates to states that were removed from
conservation laws and (ii) be able to correctly compute total abundances
after updating the state variables */
fx_rdata(x_rdata_.data(), x.data(), state_.total_cl.data());
fx0_fixedParameters(x_rdata_.data(), tstart_, state_.unscaledParameters.data(),
state_.fixedParameters.data());
auto resettedStateIdxs = fx0_fixedParameters(
x_rdata_.data(), tstart_, state_.unscaledParameters.data(),
state_.fixedParameters.data());
fx_solver(x.data(), x_rdata_.data());
/* update total abundances */
ftotal_cl(state_.total_cl.data(), x_rdata_.data());
return resettedStateIdxs;
}

void Model::fsx0(AmiVectorArray &sx, const AmiVector &x) {
Expand All @@ -1178,18 +1180,20 @@ void Model::fsx0(AmiVectorArray &sx, const AmiVector &x) {
}
}

void Model::fsx0_fixedParameters(AmiVectorArray &sx, const AmiVector &x) {
void Model::fsx0_fixedParameters(AmiVectorArray &sx,
const AmiVector &x,
const std::set<int>& resettedStateIdxs) {
if (!getReinitializeFixedParameterInitialStates())
return;
realtype *stcl = nullptr;
for (int ip = 0; ip < nplist(); ip++) {
if (ncl() > 0)
stcl = &state_.stotal_cl.at(plist(ip) * ncl());
fsx_rdata(sx_rdata_.data(), sx.data(ip), stcl, plist(ip));
fsx0_fixedParameters(sx_rdata_.data(), tstart_, x.data(),
state_.unscaledParameters.data(),
fsx0_fixedParameters(sx_rdata_.data(), tstart_,
x.data(), state_.unscaledParameters.data(),
state_.fixedParameters.data(),
plist(ip));
plist(ip), resettedStateIdxs);
fsx_solver(sx.data(ip), sx_rdata_.data());
fstotal_cl(stcl, sx_rdata_.data(), plist(ip));
}
Expand Down Expand Up @@ -1855,7 +1859,7 @@ void Model::fw(const realtype t, const realtype *x) {
void Model::fdwdp(const realtype t, const realtype *x) {
if (!nw)
return;

fw(t, x);
dwdp_.zero();
if (pythonGenerated) {
Expand Down Expand Up @@ -1896,7 +1900,7 @@ void Model::fdwdx(const realtype t, const realtype *x) {
return;

fw(t, x);

dwdx_.zero();
if (pythonGenerated) {
if (!dwdx_hierarchical_.at(0).capacity())
Expand Down Expand Up @@ -1939,7 +1943,7 @@ void Model::fdwdw(const realtype t, const realtype *x) {
fdwdw(dwdw_.data(), t, x, state_.unscaledParameters.data(),
state_.fixedParameters.data(), state_.h.data(), w_.data(),
state_.total_cl.data());

if (always_check_finite_) {
app->checkFinite(gsl::make_span(dwdw_.get()), "dwdw");
}
Expand Down
Loading

0 comments on commit 3f3ca2e

Please sign in to comment.