From 3f3ca2eeb8c233a36ba2d6fbfd0241346207dd06 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Mon, 30 Nov 2020 22:44:07 +0100 Subject: [PATCH] Allow enabling/disabling state reinitialization for individual states (Closes #1345) --- include/amici/abstract_model.h | 18 ++++++---- include/amici/model.h | 30 +++++++++------- python/amici/ode_export.py | 42 ++++++++++++++++------ src/abstract_model.cpp | 7 ++-- src/model.cpp | 26 ++++++++------ src/model_header.ODE_template.h | 15 ++++---- src/solver.cpp | 14 ++++---- src/steadystateproblem.cpp | 10 +++--- tests/petab_test_suite/test_petab_suite.py | 7 ++-- 9 files changed, 104 insertions(+), 65 deletions(-) diff --git a/include/amici/abstract_model.h b/include/amici/abstract_model.h index e624fec402..3072fe0d94 100644 --- a/include/amici/abstract_model.h +++ b/include/amici/abstract_model.h @@ -10,6 +10,7 @@ #include #include +#include namespace amici { @@ -230,9 +231,11 @@ class AbstractModel { * @param t initial time * @param p parameter vector * @param k constant vector + * @return set of indices of states that have been reset */ - virtual void fx0_fixedParameters(realtype *x0, const realtype t, - const realtype *p, const realtype *k); + virtual std::set fx0_fixedParameters( + realtype *x0, const realtype t, + const realtype *p, const realtype *k); /** * @brief Model specific implementation of fsx0_fixedParameters @@ -242,10 +245,11 @@ class AbstractModel { * @param p parameter vector * @param k constant vector * @param ip sensitivity index + * @param resettedStateIdxs set of indices of states have been reset */ - virtual void fsx0_fixedParameters(realtype *sx0, const realtype t, - const realtype *x0, const realtype *p, - const realtype *k, int ip); + virtual void fsx0_fixedParameters( + realtype *sx0, const realtype t, const realtype *x0, const realtype *p, + const realtype *k, int ip, const std::set& resettedStateIdxs); /** * @brief Model specific implementation of fsx0 @@ -626,7 +630,7 @@ class AbstractModel { virtual void fdJydy(realtype *dJydy, int iy, const realtype *p, const realtype *k, const realtype *y, const realtype *sigmay, const realtype *my); - + /** * @brief Model-specific implementation of fdJydy colptrs * @param dJydy sparse matrix to which colptrs will be written @@ -800,7 +804,7 @@ class AbstractModel { * @param dwdx sparse matrix to which rowvals will be written */ virtual void fdwdx_rowvals(SUNMatrixWrapper &dwdx); - + /** * @brief Model specific implementation of fdwdw, no w chainrule (Py) * @param dwdw partial derivative w wrt w diff --git a/include/amici/model.h b/include/amici/model.h index 9b763f1740..6ac5d8367a 100644 --- a/include/amici/model.h +++ b/include/amici/model.h @@ -210,7 +210,7 @@ class Model : public AbstractModel { using AbstractModel::fx0_fixedParameters; using AbstractModel::fy; using AbstractModel::fz; - + /** * @brief Initialize model properties. * @param x Reference to state variables @@ -597,7 +597,7 @@ class Model : public AbstractModel { * @return Observable IDs */ virtual std::vector getObservableIds() const; - + /** * @brief Checks whether the defined noise model is gaussian, i.e., the nllh is quadratic * @return boolean flag @@ -1204,8 +1204,9 @@ class Model : public AbstractModel { * @brief Set only those initial states that are specified via * fixed parameters. * @param x Output buffer. + * @return set of indices of states that have been reset */ - void fx0_fixedParameters(AmiVector &x); + std::set fx0_fixedParameters(AmiVector &x); /** * @brief Compute/get initial value for initial state sensitivities. @@ -1219,8 +1220,11 @@ class Model : public AbstractModel { * from `amici::Model::fx0_fixedParameters`. * @param sx Output buffer for state sensitivities * @param x State variables + * @param resettedStateIdxs set of indices of states have been reset */ - void fsx0_fixedParameters(AmiVectorArray &sx, const AmiVector &x); + void fsx0_fixedParameters(AmiVectorArray &sx, + const AmiVector &x, + const std::set& resettedStateIdxs); /** * @brief Compute sensitivity of derivative initial states sensitivities @@ -1298,13 +1302,13 @@ class Model : public AbstractModel { /** Flag indicating Matlab- or Python-based model generation */ bool pythonGenerated; - + /** * @brief getter for dxdotdp (matlab generated) * @return dxdotdp */ const AmiVectorArray &get_dxdotdp() const; - + /** * @brief getter for dxdotdp (python generated) * @return dxdotdp @@ -1642,7 +1646,7 @@ class Model : public AbstractModel { * @param x Array with the states */ void fdwdx(realtype t, const realtype *x); - + /** * @brief Compute self derivative for recurring terms in xdot. * @param t Timepoint @@ -1752,13 +1756,13 @@ class Model : public AbstractModel { /** Sparse dwdx temporary storage (dimension: `ndwdx`) */ mutable SUNMatrixWrapper dwdx_; - + /** Sparse dwdp temporary storage (dimension: `ndwdp`) */ mutable SUNMatrixWrapper dwdp_; - + /** Dense Mass matrix (dimension: `nx_solver` x `nx_solver`) */ mutable SUNMatrixWrapper M_; - + /** * Temporary storage of `dxdotdp_full` data across functions (Python only) * (dimension: `nplist` x `nx_solver`, nnz: dynamic, @@ -1780,7 +1784,7 @@ class Model : public AbstractModel { * type `CSC_MAT`) */ mutable SUNMatrixWrapper dxdotdp_implicit; - + /** * Temporary storage of `dxdotdx_explicit` data across functions (Python only) * (dimension: `nplist` x `nx_solver`, nnz: 'nxdotdotdx_explicit', @@ -2005,10 +2009,10 @@ class Model : public AbstractModel { /** Sparse dwdw temporary storage (dimension: `ndwdw`) */ mutable SUNMatrixWrapper dwdw_; - + /** Sparse dwdx implicit temporary storage (dimension: `ndwdx`) */ mutable std::vector dwdx_hierarchical_; - + /** Recursion */ int w_recursion_depth_ {0}; }; diff --git a/python/amici/ode_export.py b/python/amici/ode_export.py index 0724049724..2e2debeea6 100644 --- a/python/amici/ode_export.py +++ b/python/amici/ode_export.py @@ -170,6 +170,7 @@ 'signature': '(realtype *x0_fixedParameters, const realtype t, ' 'const realtype *p, const realtype *k)', + 'ret_type': 'std::set' }, 'sx0': { 'signature': @@ -180,7 +181,7 @@ 'signature': '(realtype *sx0_fixedParameters, const realtype t, ' 'const realtype *x0, const realtype *p, const realtype *k, ' - 'const int ip)', + 'const int ip, const std::set &resettedStateIdxs)', }, 'xdot': { 'signature': @@ -2430,11 +2431,12 @@ def _write_function_file(self, function: str) -> None: '#include "sundials/sundials_types.h"', '', '#include ', + '#include ', ] # function signature signature = self.functions[function]['signature'] - + ret_type = self.functions[function].get('ret_type', 'void') lines.append('') for sym in self.model.sym_names(): @@ -2452,7 +2454,7 @@ def _write_function_file(self, function: str) -> None: '', ]) - lines.append(f'void {function}_{self.model_name}{signature}{{') + lines.append(f'{ret_type} {function}_{self.model_name}{signature}{{') # function body body = self._get_function_body(function, equations) @@ -2546,9 +2548,10 @@ def _write_function_index(self, function: str, indextype: str) -> None: lines.append(' ' + ', '.join(map(str, values))) lines.append("};") + ret_type = self.functions[function].get('ret_type', 'void') lines.extend([ '', - f'void {function}_{indextype}_{self.model_name}{signature}{{', + f'{ret_type} {function}_{indextype}_{self.model_name}{signature}{{', ]) if len(values): @@ -2630,18 +2633,34 @@ def _get_function_body(self, ): if not formula.is_zero: expressions.append( + f'if(resettedStateIdxs.find({index}) != ' + 'resettedStateIdxs.end()) ' f'{function}[{index}] = ' f'{_print_with_exception(formula)};') cases[ipar] = expressions lines.extend(get_switch_statement('ip', cases, 1)) elif function == 'x0_fixedParameters': + lines.append("realtype tmp;\n" + "std::set resettedStateIdxs;") for index, formula in zip( self.model._x0_fixedParameters_idx, equations ): - lines.append(f'{function}[{index}] = ' - f'{_print_with_exception(formula)};') + lines.append(f'tmp = {_print_with_exception(formula)};\n' + 'if(!std::isnan(tmp)) ' + f'{function}[{index}] = tmp;\n' + f'resettedStateIdxs.emplace({index});') + lines.append("return resettedStateIdxs;") + elif function == 'x0': + lines.append("realtype tmp;") + + lines.extend([ + f' tmp = {_print_with_exception(math)};' + f' if(!std::isnan(tmp)) {function}[{index}] = tmp;' + for index, math in enumerate(equations) + if not (math == 0 or math == 0.0) + ]) elif function in sensi_functions: cases = {ipar: _get_sym_lines_array(equations[:, ipar], function, @@ -3007,8 +3026,9 @@ def get_function_extern_declaration(fun: str, name: str) -> str: c++ function definition string """ - return \ - f'extern void {fun}_{name}{functions[fun]["signature"]};' + signature = functions[fun]["signature"] + ret_type = functions[fun].get('ret_type', 'void') + return f'extern {ret_type} {fun}_{name}{signature};' def get_sunindex_extern_declaration(fun: str, name: str, @@ -3051,11 +3071,12 @@ def get_model_override_implementation(fun: str, name: str) -> str: """ return \ - 'virtual void f{fun}{signature} override {{\n' \ + 'virtual {ret_type} f{fun}{signature} override {{\n' \ '{ind8}{fun}_{name}{eval_signature};\n' \ '{ind4}}}\n'.format( ind4=' '*4, ind8=' '*8, + ret_type=functions[fun].get('ret_type', 'void'), fun=fun, name=name, signature=functions[fun]["signature"], @@ -3086,11 +3107,12 @@ def get_sunindex_override_implementation(fun: str, name: str, index_arg_eval = ', index' if fun in multiobs_functions else '' return \ - 'virtual void f{fun}_{indextype}{signature} override {{\n' \ + 'virtual {ret_type} f{fun}_{indextype}{signature} override {{\n' \ '{ind8}{fun}_{indextype}_{name}{eval_signature};\n' \ '{ind4}}}\n'.format( ind4=' '*4, ind8=' '*8, + ret_type=functions[fun].get('ret_type', 'void'), fun=fun, indextype=indextype, name=name, diff --git a/src/abstract_model.cpp b/src/abstract_model.cpp index 59f0d45353..ffd16cf59f 100644 --- a/src/abstract_model.cpp +++ b/src/abstract_model.cpp @@ -31,13 +31,14 @@ AbstractModel::isFixedParameterStateReinitializationAllowed() const return false; } -void +std::set AbstractModel::fx0_fixedParameters(realtype* /*x0*/, const realtype /*t*/, const realtype* /*p*/, const realtype* /*k*/) { // no-op default implementation + return std::set(); } void @@ -46,7 +47,9 @@ AbstractModel::fsx0_fixedParameters(realtype* /*sx0*/, const realtype* /*x0*/, const realtype* /*p*/, const realtype* /*k*/, - const int /*ip*/) + const int /*ip*/, + const std::set& /*resettedStateIdxs*/ + ) { // no-op default implementation } diff --git a/src/model.cpp b/src/model.cpp index 0bda2bf9af..44bd3066b1 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -1149,19 +1149,21 @@ void Model::fx0(AmiVector &x) { } } -void Model::fx0_fixedParameters(AmiVector &x) { +std::setModel::fx0_fixedParameters(AmiVector &x) { if (!getReinitializeFixedParameterInitialStates()) - return; + return std::set(); /* we transform to the unreduced states x_rdata and then apply x0_fixedparameters to (i) enable updates to states that were removed from conservation laws and (ii) be able to correctly compute total abundances after updating the state variables */ fx_rdata(x_rdata_.data(), x.data(), state_.total_cl.data()); - fx0_fixedParameters(x_rdata_.data(), tstart_, state_.unscaledParameters.data(), - state_.fixedParameters.data()); + auto resettedStateIdxs = fx0_fixedParameters( + x_rdata_.data(), tstart_, state_.unscaledParameters.data(), + state_.fixedParameters.data()); fx_solver(x.data(), x_rdata_.data()); /* update total abundances */ ftotal_cl(state_.total_cl.data(), x_rdata_.data()); + return resettedStateIdxs; } void Model::fsx0(AmiVectorArray &sx, const AmiVector &x) { @@ -1178,7 +1180,9 @@ void Model::fsx0(AmiVectorArray &sx, const AmiVector &x) { } } -void Model::fsx0_fixedParameters(AmiVectorArray &sx, const AmiVector &x) { +void Model::fsx0_fixedParameters(AmiVectorArray &sx, + const AmiVector &x, + const std::set& resettedStateIdxs) { if (!getReinitializeFixedParameterInitialStates()) return; realtype *stcl = nullptr; @@ -1186,10 +1190,10 @@ void Model::fsx0_fixedParameters(AmiVectorArray &sx, const AmiVector &x) { if (ncl() > 0) stcl = &state_.stotal_cl.at(plist(ip) * ncl()); fsx_rdata(sx_rdata_.data(), sx.data(ip), stcl, plist(ip)); - fsx0_fixedParameters(sx_rdata_.data(), tstart_, x.data(), - state_.unscaledParameters.data(), + fsx0_fixedParameters(sx_rdata_.data(), tstart_, + x.data(), state_.unscaledParameters.data(), state_.fixedParameters.data(), - plist(ip)); + plist(ip), resettedStateIdxs); fsx_solver(sx.data(ip), sx_rdata_.data()); fstotal_cl(stcl, sx_rdata_.data(), plist(ip)); } @@ -1855,7 +1859,7 @@ void Model::fw(const realtype t, const realtype *x) { void Model::fdwdp(const realtype t, const realtype *x) { if (!nw) return; - + fw(t, x); dwdp_.zero(); if (pythonGenerated) { @@ -1896,7 +1900,7 @@ void Model::fdwdx(const realtype t, const realtype *x) { return; fw(t, x); - + dwdx_.zero(); if (pythonGenerated) { if (!dwdx_hierarchical_.at(0).capacity()) @@ -1939,7 +1943,7 @@ void Model::fdwdw(const realtype t, const realtype *x) { fdwdw(dwdw_.data(), t, x, state_.unscaledParameters.data(), state_.fixedParameters.data(), state_.h.data(), w_.data(), state_.total_cl.data()); - + if (always_check_finite_) { app->checkFinite(gsl::make_span(dwdw_.get()), "dwdw"); } diff --git a/src/model_header.ODE_template.h b/src/model_header.ODE_template.h index 940766858a..ade7170c88 100644 --- a/src/model_header.ODE_template.h +++ b/src/model_header.ODE_template.h @@ -2,6 +2,7 @@ #define _amici_TPL_MODELNAME_h #include #include +#include #include "amici/model_ode.h" #include "amici/solver_cvodes.h" @@ -69,7 +70,7 @@ extern void sigmay_TPL_MODELNAME(realtype *sigmay, const realtype t, TPL_W_DEF extern void x0_TPL_MODELNAME(realtype *x0, const realtype t, const realtype *p, const realtype *k); -extern void x0_fixedParameters_TPL_MODELNAME(realtype *x0, const realtype t, +extern std::set x0_fixedParameters_TPL_MODELNAME(realtype *x0, const realtype t, const realtype *p, const realtype *k); extern void sx0_TPL_MODELNAME(realtype *sx0, const realtype t, @@ -78,7 +79,8 @@ extern void sx0_TPL_MODELNAME(realtype *sx0, const realtype t, extern void sx0_fixedParameters_TPL_MODELNAME(realtype *sx0, const realtype t, const realtype *x0, const realtype *p, - const realtype *k, const int ip); + const realtype *k, const int ip, + const std::set &resettedParameterIdxs); extern void xdot_TPL_MODELNAME(realtype *xdot, const realtype t, const realtype *x, const realtype *p, const realtype *k, const realtype *h, @@ -572,8 +574,9 @@ class Model_TPL_MODELNAME : public amici::Model_ODE { virtual void fsx0_fixedParameters(realtype *sx0, const realtype t, const realtype *x0, const realtype *p, const realtype *k, - const int ip) override { - sx0_fixedParameters_TPL_MODELNAME(sx0, t, x0, p, k, ip); + const int ip, + const std::set &resettedStateIdxs) override { + sx0_fixedParameters_TPL_MODELNAME(sx0, t, x0, p, k, ip, resettedStateIdxs); } /** model specific implementation of fsz @@ -611,10 +614,10 @@ class Model_TPL_MODELNAME : public amici::Model_ODE { * @param p parameter vector * @param k constant vector **/ - virtual void fx0_fixedParameters(realtype *x0, const realtype t, + virtual std::set fx0_fixedParameters(realtype *x0, const realtype t, const realtype *p, const realtype *k) override { - x0_fixedParameters_TPL_MODELNAME(x0, t, p, k); + return x0_fixedParameters_TPL_MODELNAME(x0, t, p, k); } /** model specific implementation for fxdot diff --git a/src/solver.cpp b/src/solver.cpp index 56b7239fc7..529778c399 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -66,7 +66,7 @@ int Solver::run(const realtype tout) const { setStopTime(tout); clock_t starttime = clock(); int status = AMICI_SUCCESS; - + apply_max_num_steps(); if (nx() > 0) { if (getAdjInitDone()) { @@ -83,7 +83,7 @@ int Solver::run(const realtype tout) const { int Solver::step(const realtype tout) const { int status = AMICI_SUCCESS; - + apply_max_num_steps(); if (nx() > 0) { if (getAdjInitDone()) { @@ -99,7 +99,7 @@ int Solver::step(const realtype tout) const { void Solver::runB(const realtype tout) const { clock_t starttime = clock(); - + apply_max_num_steps_B(); if (nx() > 0) { solveB(tout, AMICI_NORMAL); @@ -141,7 +141,7 @@ void Solver::setup(const realtype t0, Model *model, const AmiVector &x0, if (nx() == 0) return; - + initializeLinearSolver(model); initializeNonLinearSolver(); @@ -188,7 +188,7 @@ void Solver::setupB(int *which, const realtype tf, Model *model, if (nx() == 0) return; - + initializeLinearSolverB(model, *which); initializeNonLinearSolverB(*which); @@ -223,11 +223,11 @@ void Solver::setupSteadystate(const realtype t0, Model *model, const AmiVector & } void Solver::updateAndReinitStatesAndSensitivities(Model *model) { - model->fx0_fixedParameters(x_); + auto resettedStateIdxs = model->fx0_fixedParameters(x_); reInit(t_, x_, dx_); if (getSensitivityOrder() >= SensitivityOrder::first) { - model->fsx0_fixedParameters(sx_, x_); + model->fsx0_fixedParameters(sx_, x_, resettedStateIdxs); if (getSensitivityMethod() == SensitivityMethod::forward) sensReInit(sx_, sdx_); } diff --git a/src/steadystateproblem.cpp b/src/steadystateproblem.cpp index 1157085958..87db717dc4 100644 --- a/src/steadystateproblem.cpp +++ b/src/steadystateproblem.cpp @@ -43,7 +43,7 @@ SteadystateProblem::SteadystateProblem(const Solver &solver, const Model &model) void SteadystateProblem::workSteadyStateProblem(Solver *solver, Model *model, int it) { - + /* process solver handling for pre- or postequilibration */ if (it == -1) { /* solver was not run before, set up everything */ @@ -56,8 +56,8 @@ void SteadystateProblem::workSteadyStateProblem(Solver *solver, Model *model, /* solver was run before, extract current state from solver */ solver->writeSolution(&t_, x_, dx_, sx_, xQ_); } - - /* create a Newton solver obejct */ + + /* create a Newton solver object */ auto newtonSolver = NewtonSolver::getSolver(&t_, &x_, *solver, model); /* Compute steady state and get the computation time */ @@ -76,7 +76,7 @@ void SteadystateProblem::workSteadyStateProblem(Solver *solver, Model *model, /* No steady state could be inferred. Store simulation state */ storeSimulationState(model, solver->getSensitivityOrder() >= SensitivityOrder::first); - throw AmiException("Steady state sensitvitiy computation failed due " + throw AmiException("Steady state sensitivity computation failed due " "to unsuccessful factorization of RHS Jacobian"); } } @@ -480,7 +480,7 @@ void SteadystateProblem::applyNewtonsMethod(Model *model, int ix = 0; double gamma = 1.0; bool compNewStep = true; - + if (model->nx_solver == 0) return; diff --git a/tests/petab_test_suite/test_petab_suite.py b/tests/petab_test_suite/test_petab_suite.py index f799c28080..8b406a6a29 100755 --- a/tests/petab_test_suite/test_petab_suite.py +++ b/tests/petab_test_suite/test_petab_suite.py @@ -16,7 +16,6 @@ from amici.petab_import import import_petab_problem, PysbPetabProblem from amici.petab_objective import ( simulate_petab, rdatas_to_measurement_df, create_parameterized_edatas) -from amici import SteadyStateSensitivityMode_simulationFSA logger = get_logger(__name__, logging.DEBUG) set_log_level(get_logger("amici.petab_import"), logging.DEBUG) @@ -127,12 +126,12 @@ def check_derivatives(problem: petab.Problem, model: amici.Model) -> None: problem_parameters = {t.Index: getattr(t, petab.NOMINAL_VALUE) for t in problem.parameter_df.itertuples()} solver = model.getSolver() - solver.setSensitivityMethod(amici.SensitivityMethod_forward) - solver.setSensitivityOrder(amici.SensitivityOrder_first) + solver.setSensitivityMethod(amici.SensitivityMethod.forward) + solver.setSensitivityOrder(amici.SensitivityOrder.first) # Required for case 9 to not fail in # amici::NewtonSolver::computeNewtonSensis model.setSteadyStateSensitivityMode( - SteadyStateSensitivityMode_simulationFSA) + amici.SteadyStateSensitivityMode.simulationFSA) def assert_true(x): assert x