diff --git a/cocofest/examples/getting_started/pulse_duration_optimization_nmpc_cyclic.py b/cocofest/examples/getting_started/pulse_duration_optimization_nmpc_cyclic.py index d344422d..1b122aff 100644 --- a/cocofest/examples/getting_started/pulse_duration_optimization_nmpc_cyclic.py +++ b/cocofest/examples/getting_started/pulse_duration_optimization_nmpc_cyclic.py @@ -20,8 +20,10 @@ # --- Build nmpc cyclic --- # n_total_cycles = 8 minimum_pulse_duration = DingModelPulseDurationFrequencyWithFatigue().pd0 +fes_model = DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10) +fes_model.alpha_a = -4.0 * 10e-1 # Increasing the fatigue rate to make the fatigue more visible nmpc = OcpFesNmpcCyclic( - model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10), + model=fes_model, n_stim=30, n_shooting=5, final_time=1, diff --git a/cocofest/models/ding2003.py b/cocofest/models/ding2003.py index c4fa23e9..fd694434 100644 --- a/cocofest/models/ding2003.py +++ b/cocofest/models/ding2003.py @@ -34,7 +34,6 @@ def __init__( model_name: str = "ding2003", muscle_name: str = None, sum_stim_truncation: int = None, - stim_prev: list[float] = None, ): super().__init__() self._model_name = model_name @@ -52,8 +51,6 @@ def __init__( self.tau2 = 0.060 # Close value from Ding's experimentation [2] (s) self.km_rest = 0.103 # Value from Ding's experimentation [1] (unitless) - self.stim_prev = stim_prev - def set_a_rest(self, model, a_rest: MX | float): # models is required for bioptim compatibility self.a_rest = a_rest @@ -148,8 +145,6 @@ def system_dynamics( ------- The value of the derivative of each state dx/dt at the current time t """ - if self.stim_prev: - t_stim_prev = self.stim_prev + t_stim_prev r0 = self.km_rest + self.r0_km_relationship # Simplification cn_dot = self.cn_dot_fun(cn, r0, t, t_stim_prev=t_stim_prev) # Equation n°1 f_dot = self.f_dot_fun( diff --git a/cocofest/models/ding2007.py b/cocofest/models/ding2007.py index 1886d66a..4e06a2ee 100644 --- a/cocofest/models/ding2007.py +++ b/cocofest/models/ding2007.py @@ -32,10 +32,9 @@ def __init__( model_name: str = "ding_2007", muscle_name: str = None, sum_stim_truncation: int = None, - stim_prev: list[float] = None, ): super(DingModelPulseDurationFrequency, self).__init__( - model_name=model_name, muscle_name=muscle_name, sum_stim_truncation=sum_stim_truncation, stim_prev=stim_prev + model_name=model_name, muscle_name=muscle_name, sum_stim_truncation=sum_stim_truncation ) self._with_fatigue = False self.impulse_time = None diff --git a/cocofest/models/ding2007_with_fatigue.py b/cocofest/models/ding2007_with_fatigue.py index ec072e49..4a288d5c 100644 --- a/cocofest/models/ding2007_with_fatigue.py +++ b/cocofest/models/ding2007_with_fatigue.py @@ -31,10 +31,9 @@ def __init__( model_name: str = "ding_2007_with_fatigue", muscle_name: str = None, sum_stim_truncation: int = None, - stim_prev: list[float] = None, ): super(DingModelPulseDurationFrequencyWithFatigue, self).__init__( - model_name=model_name, muscle_name=muscle_name, sum_stim_truncation=sum_stim_truncation, stim_prev=stim_prev + model_name=model_name, muscle_name=muscle_name, sum_stim_truncation=sum_stim_truncation ) self._with_fatigue = True @@ -141,8 +140,6 @@ def system_dynamics( ------- The value of the derivative of each state dx/dt at the current time t """ - if self.stim_prev: - t_stim_prev = self.stim_prev + t_stim_prev r0 = km + self.r0_km_relationship # Simplification cn_dot = self.cn_dot_fun(cn, r0, t, t_stim_prev=t_stim_prev) # Equation n°1 from Ding's 2003 article a_calculated = self.a_calculation(a_scale=a, impulse_time=impulse_time) # Equation n°3 from Ding's 2007 article diff --git a/cocofest/optimization/fes_ocp_nmpc_cyclic.py b/cocofest/optimization/fes_ocp_nmpc_cyclic.py index 60e23e4d..f3502ddd 100644 --- a/cocofest/optimization/fes_ocp_nmpc_cyclic.py +++ b/cocofest/optimization/fes_ocp_nmpc_cyclic.py @@ -339,21 +339,21 @@ def _set_objective( return objective_functions def _nmpc_sanity_check(self): - if self.n_total_cycles is None: - raise ValueError("n_total_cycles must be set") - if self.n_simultaneous_cycles is None: - raise ValueError("n_simultaneous_cycles must be set") - if self.n_cycle_to_advance is None: - raise ValueError("n_cycle_to_advance must be set") - if self.cycle_to_keep is None: - raise ValueError("cycle_to_keep must be set") - - if self.n_total_cycles % self.n_cycle_to_advance != 0: - raise ValueError("The number of n_total_cycles must be a multiple of the number n_cycle_to_advance") + if not isinstance(self.n_total_cycles, int): + raise TypeError("n_total_cycles must be an integer") + if not isinstance(self.n_simultaneous_cycles, int): + raise TypeError("n_simultaneous_cycles must be an integer") + if not isinstance(self.n_cycle_to_advance, int): + raise TypeError("n_cycle_to_advance must be an integer") + if not isinstance(self.cycle_to_keep, str): + raise TypeError("cycle_to_keep must be a string") if self.n_cycle_to_advance > self.n_simultaneous_cycles: raise ValueError("The number of n_simultaneous_cycles must be higher than the number of n_cycle_to_advance") + if self.n_total_cycles % self.n_cycle_to_advance != 0: + raise ValueError("The number of n_total_cycles must be a multiple of the number n_cycle_to_advance") + if self.cycle_to_keep not in ["first", "middle", "last"]: raise ValueError("cycle_to_keep must be either 'first', 'middle' or 'last'") if self.cycle_to_keep != "middle": diff --git a/tests/shard1/test_nmpc_cyclic.py b/tests/shard1/test_nmpc_cyclic.py new file mode 100644 index 00000000..a63b3ff4 --- /dev/null +++ b/tests/shard1/test_nmpc_cyclic.py @@ -0,0 +1,256 @@ +import pytest +import re +import numpy as np + +from bioptim import OdeSolver +from cocofest import OcpFesNmpcCyclic, DingModelPulseDurationFrequencyWithFatigue + + +def test_nmpc_cyclic(): + # --- Build target force --- # + target_time = np.linspace(0, 1, 100) + target_force = abs(np.sin(target_time * np.pi)) * 50 + force_tracking = [target_time, target_force] + + # --- Build nmpc cyclic --- # + n_total_cycles = 6 + n_stim = 10 + n_shooting = 5 + + minimum_pulse_duration = DingModelPulseDurationFrequencyWithFatigue().pd0 + fes_model = DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10) + fes_model.alpha_a = -4.0 * 10e-1 # Increasing the fatigue rate to make the fatigue more visible + + nmpc = OcpFesNmpcCyclic( + model=fes_model, + n_stim=n_stim, + n_shooting=n_shooting, + final_time=1, + pulse_duration={ + "min": minimum_pulse_duration, + "max": 0.0006, + "bimapping": False, + }, + objective={"force_tracking": force_tracking}, + n_total_cycles=n_total_cycles, + n_simultaneous_cycles=3, + n_cycle_to_advance=1, + cycle_to_keep="middle", + use_sx=True, + ode_solver=OdeSolver.COLLOCATION(), + ) + + nmpc.prepare_nmpc() + nmpc.solve() + + # --- Show results --- # + time = [j for sub in nmpc.result["time"] for j in sub] + fatigue = [j for sub in nmpc.result["states"]["A"] for j in sub] + force = [j for sub in nmpc.result["states"]["F"] for j in sub] + + np.testing.assert_almost_equal(len(time), n_total_cycles*n_stim*n_shooting*(nmpc.ode_solver.polynomial_degree+1)) + np.testing.assert_almost_equal(len(fatigue), len(time)) + np.testing.assert_almost_equal(len(force), len(time)) + + np.testing.assert_almost_equal(time[0], 0.0) + np.testing.assert_almost_equal(fatigue[0], 4796.3120362970285) + np.testing.assert_almost_equal(force[0], 3.0948778396159535) + + np.testing.assert_almost_equal(time[750], 3.0000000000000013) + np.testing.assert_almost_equal(fatigue[750], 4427.259641834449) + np.testing.assert_almost_equal(force[750], 4.508999252965375) + + np.testing.assert_almost_equal(time[-1], 5.998611363115943) + np.testing.assert_almost_equal(fatigue[-1], 4063.8504572735123) + np.testing.assert_almost_equal(force[-1], 5.661514731665669) + + +def test_all_nmpc_errors(): + with pytest.raises( + TypeError, + match=re.escape( + "n_total_cycles must be an integer" + ), + ): + OcpFesNmpcCyclic( + model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10), + n_stim=10, + n_shooting=5, + final_time=1, + pulse_duration={ + "min": 0.0003, + "max": 0.0006, + "bimapping": False, + }, + n_total_cycles=None, + ) + + with pytest.raises( + TypeError, + match=re.escape( + "n_simultaneous_cycles must be an integer" + ), + ): + OcpFesNmpcCyclic( + model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10), + n_stim=10, + n_shooting=5, + final_time=1, + pulse_duration={ + "min": 0.0003, + "max": 0.0006, + "bimapping": False, + }, + n_total_cycles=5, + ) + + with pytest.raises( + TypeError, + match=re.escape( + "n_cycle_to_advance must be an integer" + ), + ): + OcpFesNmpcCyclic( + model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10), + n_stim=10, + n_shooting=5, + final_time=1, + pulse_duration={ + "min": 0.0003, + "max": 0.0006, + "bimapping": False, + }, + n_total_cycles=5, + n_simultaneous_cycles=3, + ) + + with pytest.raises( + TypeError, + match=re.escape( + "cycle_to_keep must be a string" + ), + ): + OcpFesNmpcCyclic( + model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10), + n_stim=10, + n_shooting=5, + final_time=1, + pulse_duration={ + "min": 0.0003, + "max": 0.0006, + "bimapping": False, + }, + n_total_cycles=5, + n_simultaneous_cycles=3, + n_cycle_to_advance=1, + ) + + with pytest.raises( + ValueError, + match=re.escape( + "The number of n_simultaneous_cycles must be higher than the number of n_cycle_to_advance" + ), + ): + OcpFesNmpcCyclic( + model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10), + n_stim=10, + n_shooting=5, + final_time=1, + pulse_duration={ + "min": 0.0003, + "max": 0.0006, + "bimapping": False, + }, + n_total_cycles=5, + n_simultaneous_cycles=3, + n_cycle_to_advance=6, + cycle_to_keep="middle", + ) + + with pytest.raises( + ValueError, + match=re.escape( + "The number of n_total_cycles must be a multiple of the number n_cycle_to_advance" + ), + ): + OcpFesNmpcCyclic( + model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10), + n_stim=10, + n_shooting=5, + final_time=1, + pulse_duration={ + "min": 0.0003, + "max": 0.0006, + "bimapping": False, + }, + n_total_cycles=5, + n_simultaneous_cycles=3, + n_cycle_to_advance=2, + cycle_to_keep="middle", + ) + + with pytest.raises( + ValueError, + match=re.escape( + "cycle_to_keep must be either 'first', 'middle' or 'last'" + ), + ): + OcpFesNmpcCyclic( + model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10), + n_stim=10, + n_shooting=5, + final_time=1, + pulse_duration={ + "min": 0.0003, + "max": 0.0006, + "bimapping": False, + }, + n_total_cycles=5, + n_simultaneous_cycles=3, + n_cycle_to_advance=1, + cycle_to_keep="between", + ) + + with pytest.raises( + NotImplementedError, + match=re.escape( + "Only 'middle' cycle_to_keep is implemented" + ), + ): + OcpFesNmpcCyclic( + model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10), + n_stim=10, + n_shooting=5, + final_time=1, + pulse_duration={ + "min": 0.0003, + "max": 0.0006, + "bimapping": False, + }, + n_total_cycles=5, + n_simultaneous_cycles=3, + n_cycle_to_advance=1, + cycle_to_keep="first", + ) + + with pytest.raises( + NotImplementedError, + match=re.escape( + "Only 3 simultaneous cycles are implemented yet work in progress" + ), + ): + OcpFesNmpcCyclic( + model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10), + n_stim=10, + n_shooting=5, + final_time=1, + pulse_duration={ + "min": 0.0003, + "max": 0.0006, + "bimapping": False, + }, + n_total_cycles=5, + n_simultaneous_cycles=6, + n_cycle_to_advance=1, + cycle_to_keep="middle", + )