Skip to content

Commit

Permalink
Adding tests for nmpc cyclic
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin CO committed Aug 29, 2024
1 parent 6df6b3c commit a67d78c
Show file tree
Hide file tree
Showing 6 changed files with 272 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
# --- Build nmpc cyclic --- #
n_total_cycles = 8
minimum_pulse_duration = DingModelPulseDurationFrequencyWithFatigue().pd0
fes_model = DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10)
fes_model.alpha_a = -4.0 * 10e-1 # Increasing the fatigue rate to make the fatigue more visible
nmpc = OcpFesNmpcCyclic(
model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10),
model=fes_model,
n_stim=30,
n_shooting=5,
final_time=1,
Expand Down
5 changes: 0 additions & 5 deletions cocofest/models/ding2003.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def __init__(
model_name: str = "ding2003",
muscle_name: str = None,
sum_stim_truncation: int = None,
stim_prev: list[float] = None,
):
super().__init__()
self._model_name = model_name
Expand All @@ -52,8 +51,6 @@ def __init__(
self.tau2 = 0.060 # Close value from Ding's experimentation [2] (s)
self.km_rest = 0.103 # Value from Ding's experimentation [1] (unitless)

self.stim_prev = stim_prev

def set_a_rest(self, model, a_rest: MX | float):
# models is required for bioptim compatibility
self.a_rest = a_rest
Expand Down Expand Up @@ -148,8 +145,6 @@ def system_dynamics(
-------
The value of the derivative of each state dx/dt at the current time t
"""
if self.stim_prev:
t_stim_prev = self.stim_prev + t_stim_prev
r0 = self.km_rest + self.r0_km_relationship # Simplification
cn_dot = self.cn_dot_fun(cn, r0, t, t_stim_prev=t_stim_prev) # Equation n°1
f_dot = self.f_dot_fun(
Expand Down
3 changes: 1 addition & 2 deletions cocofest/models/ding2007.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,9 @@ def __init__(
model_name: str = "ding_2007",
muscle_name: str = None,
sum_stim_truncation: int = None,
stim_prev: list[float] = None,
):
super(DingModelPulseDurationFrequency, self).__init__(
model_name=model_name, muscle_name=muscle_name, sum_stim_truncation=sum_stim_truncation, stim_prev=stim_prev
model_name=model_name, muscle_name=muscle_name, sum_stim_truncation=sum_stim_truncation
)
self._with_fatigue = False
self.impulse_time = None
Expand Down
5 changes: 1 addition & 4 deletions cocofest/models/ding2007_with_fatigue.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ def __init__(
model_name: str = "ding_2007_with_fatigue",
muscle_name: str = None,
sum_stim_truncation: int = None,
stim_prev: list[float] = None,
):
super(DingModelPulseDurationFrequencyWithFatigue, self).__init__(
model_name=model_name, muscle_name=muscle_name, sum_stim_truncation=sum_stim_truncation, stim_prev=stim_prev
model_name=model_name, muscle_name=muscle_name, sum_stim_truncation=sum_stim_truncation
)
self._with_fatigue = True

Expand Down Expand Up @@ -141,8 +140,6 @@ def system_dynamics(
-------
The value of the derivative of each state dx/dt at the current time t
"""
if self.stim_prev:
t_stim_prev = self.stim_prev + t_stim_prev
r0 = km + self.r0_km_relationship # Simplification
cn_dot = self.cn_dot_fun(cn, r0, t, t_stim_prev=t_stim_prev) # Equation n°1 from Ding's 2003 article
a_calculated = self.a_calculation(a_scale=a, impulse_time=impulse_time) # Equation n°3 from Ding's 2007 article
Expand Down
22 changes: 11 additions & 11 deletions cocofest/optimization/fes_ocp_nmpc_cyclic.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,21 +339,21 @@ def _set_objective(
return objective_functions

def _nmpc_sanity_check(self):
if self.n_total_cycles is None:
raise ValueError("n_total_cycles must be set")
if self.n_simultaneous_cycles is None:
raise ValueError("n_simultaneous_cycles must be set")
if self.n_cycle_to_advance is None:
raise ValueError("n_cycle_to_advance must be set")
if self.cycle_to_keep is None:
raise ValueError("cycle_to_keep must be set")

if self.n_total_cycles % self.n_cycle_to_advance != 0:
raise ValueError("The number of n_total_cycles must be a multiple of the number n_cycle_to_advance")
if not isinstance(self.n_total_cycles, int):
raise TypeError("n_total_cycles must be an integer")
if not isinstance(self.n_simultaneous_cycles, int):
raise TypeError("n_simultaneous_cycles must be an integer")
if not isinstance(self.n_cycle_to_advance, int):
raise TypeError("n_cycle_to_advance must be an integer")
if not isinstance(self.cycle_to_keep, str):
raise TypeError("cycle_to_keep must be a string")

if self.n_cycle_to_advance > self.n_simultaneous_cycles:
raise ValueError("The number of n_simultaneous_cycles must be higher than the number of n_cycle_to_advance")

if self.n_total_cycles % self.n_cycle_to_advance != 0:
raise ValueError("The number of n_total_cycles must be a multiple of the number n_cycle_to_advance")

if self.cycle_to_keep not in ["first", "middle", "last"]:
raise ValueError("cycle_to_keep must be either 'first', 'middle' or 'last'")
if self.cycle_to_keep != "middle":
Expand Down
256 changes: 256 additions & 0 deletions tests/shard1/test_nmpc_cyclic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
import pytest
import re
import numpy as np

from bioptim import OdeSolver
from cocofest import OcpFesNmpcCyclic, DingModelPulseDurationFrequencyWithFatigue


def test_nmpc_cyclic():
# --- Build target force --- #
target_time = np.linspace(0, 1, 100)
target_force = abs(np.sin(target_time * np.pi)) * 50
force_tracking = [target_time, target_force]

# --- Build nmpc cyclic --- #
n_total_cycles = 6
n_stim = 10
n_shooting = 5

minimum_pulse_duration = DingModelPulseDurationFrequencyWithFatigue().pd0
fes_model = DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10)
fes_model.alpha_a = -4.0 * 10e-1 # Increasing the fatigue rate to make the fatigue more visible

nmpc = OcpFesNmpcCyclic(
model=fes_model,
n_stim=n_stim,
n_shooting=n_shooting,
final_time=1,
pulse_duration={
"min": minimum_pulse_duration,
"max": 0.0006,
"bimapping": False,
},
objective={"force_tracking": force_tracking},
n_total_cycles=n_total_cycles,
n_simultaneous_cycles=3,
n_cycle_to_advance=1,
cycle_to_keep="middle",
use_sx=True,
ode_solver=OdeSolver.COLLOCATION(),
)

nmpc.prepare_nmpc()
nmpc.solve()

# --- Show results --- #
time = [j for sub in nmpc.result["time"] for j in sub]
fatigue = [j for sub in nmpc.result["states"]["A"] for j in sub]
force = [j for sub in nmpc.result["states"]["F"] for j in sub]

np.testing.assert_almost_equal(len(time), n_total_cycles*n_stim*n_shooting*(nmpc.ode_solver.polynomial_degree+1))
np.testing.assert_almost_equal(len(fatigue), len(time))
np.testing.assert_almost_equal(len(force), len(time))

np.testing.assert_almost_equal(time[0], 0.0)
np.testing.assert_almost_equal(fatigue[0], 4796.3120362970285)
np.testing.assert_almost_equal(force[0], 3.0948778396159535)

np.testing.assert_almost_equal(time[750], 3.0000000000000013)
np.testing.assert_almost_equal(fatigue[750], 4427.259641834449)
np.testing.assert_almost_equal(force[750], 4.508999252965375)

np.testing.assert_almost_equal(time[-1], 5.998611363115943)
np.testing.assert_almost_equal(fatigue[-1], 4063.8504572735123)
np.testing.assert_almost_equal(force[-1], 5.661514731665669)


def test_all_nmpc_errors():
with pytest.raises(
TypeError,
match=re.escape(
"n_total_cycles must be an integer"
),
):
OcpFesNmpcCyclic(
model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10),
n_stim=10,
n_shooting=5,
final_time=1,
pulse_duration={
"min": 0.0003,
"max": 0.0006,
"bimapping": False,
},
n_total_cycles=None,
)

with pytest.raises(
TypeError,
match=re.escape(
"n_simultaneous_cycles must be an integer"
),
):
OcpFesNmpcCyclic(
model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10),
n_stim=10,
n_shooting=5,
final_time=1,
pulse_duration={
"min": 0.0003,
"max": 0.0006,
"bimapping": False,
},
n_total_cycles=5,
)

with pytest.raises(
TypeError,
match=re.escape(
"n_cycle_to_advance must be an integer"
),
):
OcpFesNmpcCyclic(
model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10),
n_stim=10,
n_shooting=5,
final_time=1,
pulse_duration={
"min": 0.0003,
"max": 0.0006,
"bimapping": False,
},
n_total_cycles=5,
n_simultaneous_cycles=3,
)

with pytest.raises(
TypeError,
match=re.escape(
"cycle_to_keep must be a string"
),
):
OcpFesNmpcCyclic(
model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10),
n_stim=10,
n_shooting=5,
final_time=1,
pulse_duration={
"min": 0.0003,
"max": 0.0006,
"bimapping": False,
},
n_total_cycles=5,
n_simultaneous_cycles=3,
n_cycle_to_advance=1,
)

with pytest.raises(
ValueError,
match=re.escape(
"The number of n_simultaneous_cycles must be higher than the number of n_cycle_to_advance"
),
):
OcpFesNmpcCyclic(
model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10),
n_stim=10,
n_shooting=5,
final_time=1,
pulse_duration={
"min": 0.0003,
"max": 0.0006,
"bimapping": False,
},
n_total_cycles=5,
n_simultaneous_cycles=3,
n_cycle_to_advance=6,
cycle_to_keep="middle",
)

with pytest.raises(
ValueError,
match=re.escape(
"The number of n_total_cycles must be a multiple of the number n_cycle_to_advance"
),
):
OcpFesNmpcCyclic(
model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10),
n_stim=10,
n_shooting=5,
final_time=1,
pulse_duration={
"min": 0.0003,
"max": 0.0006,
"bimapping": False,
},
n_total_cycles=5,
n_simultaneous_cycles=3,
n_cycle_to_advance=2,
cycle_to_keep="middle",
)

with pytest.raises(
ValueError,
match=re.escape(
"cycle_to_keep must be either 'first', 'middle' or 'last'"
),
):
OcpFesNmpcCyclic(
model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10),
n_stim=10,
n_shooting=5,
final_time=1,
pulse_duration={
"min": 0.0003,
"max": 0.0006,
"bimapping": False,
},
n_total_cycles=5,
n_simultaneous_cycles=3,
n_cycle_to_advance=1,
cycle_to_keep="between",
)

with pytest.raises(
NotImplementedError,
match=re.escape(
"Only 'middle' cycle_to_keep is implemented"
),
):
OcpFesNmpcCyclic(
model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10),
n_stim=10,
n_shooting=5,
final_time=1,
pulse_duration={
"min": 0.0003,
"max": 0.0006,
"bimapping": False,
},
n_total_cycles=5,
n_simultaneous_cycles=3,
n_cycle_to_advance=1,
cycle_to_keep="first",
)

with pytest.raises(
NotImplementedError,
match=re.escape(
"Only 3 simultaneous cycles are implemented yet work in progress"
),
):
OcpFesNmpcCyclic(
model=DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10),
n_stim=10,
n_shooting=5,
final_time=1,
pulse_duration={
"min": 0.0003,
"max": 0.0006,
"bimapping": False,
},
n_total_cycles=5,
n_simultaneous_cycles=6,
n_cycle_to_advance=1,
cycle_to_keep="middle",
)

0 comments on commit a67d78c

Please sign in to comment.