Skip to content

Commit

Permalink
Merge pull request #66 from Kev1CO/NMPC
Browse files Browse the repository at this point in the history
Nmpc cyclic exemple
  • Loading branch information
Kev1CO authored Aug 29, 2024
2 parents 2046931 + 07094e6 commit f50d932
Show file tree
Hide file tree
Showing 11 changed files with 728 additions and 33 deletions.
2 changes: 2 additions & 0 deletions cocofest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .custom_objectives import CustomObjective
from .custom_constraints import CustomConstraint
from .models.fes_model import FesModel
from .models.ding2003 import DingModelFrequency
from .models.ding2003_with_fatigue import DingModelFrequencyWithFatigue
from .models.ding2007 import DingModelPulseDurationFrequency
Expand All @@ -10,6 +11,7 @@
from .optimization.fes_ocp import OcpFes
from .optimization.fes_identification_ocp import OcpFesId
from .optimization.fes_ocp_dynamics import OcpFesMsk
from .optimization.fes_ocp_nmpc_cyclic import OcpFesNmpcCyclic
from .integration.ivp_fes import IvpFes
from .fourier_approx import FourierSeries
from .identification.ding2003_force_parameter_identification import DingModelFrequencyForceParameterIdentification
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from cocofest import PlotCyclingResult

# Plot the results
PlotCyclingResult("cycling_fes_driven_min_residual_torque_and_fatigue_results.pkl").plot(starting_location="E")
PlotCyclingResult("cycling_fes_driven_min_residual_torque.pkl").plot(starting_location="E")
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@
)

fig, axs = plt.subplots(1, 1, figsize=(3, (1 / 3) * 7))
fig.suptitle("Muscle fatigue", fontsize=20, fontweight="bold", fontname="Times New Roman")
fig.suptitle("Muscle fatigue", fontsize=20, fontweight="bold")

axs.set_xlim(left=0, right=1.5)
plt.setp(
Expand All @@ -192,8 +192,8 @@
a_force_sum_percentage = (np.array(a_force_sum_list) / a_sum_base_line) * 100
a_fatigue_sum_percentage = (np.array(a_fatigue_sum_list) / a_sum_base_line) * 100

axs.plot(data_minimize_force["time"], a_force_sum_percentage, lw=5)
axs.plot(data_minimize_force["time"], a_fatigue_sum_percentage, lw=5)
axs.plot(data_minimize_force["time"], a_force_sum_percentage, lw=5, label="Minimize force production")
axs.plot(data_minimize_force["time"], a_fatigue_sum_percentage, lw=5, label="Maximize muscle capacity")

axs.set_xlim(left=0, right=1.5)

Expand All @@ -204,21 +204,16 @@
)

labels = axs.get_xticklabels() + axs.get_yticklabels()
[label.set_fontname("Times New Roman") for label in labels]
[label.set_fontsize(14) for label in labels]
fig.text(
0.05,
0.5,
"Fatigue percentage (%)",
"Muscle capacity (%)",
ha="center",
va="center",
rotation="vertical",
fontsize=18,
weight="bold",
font="Times New Roman",
)
axs.text(0.75, 96.3, "Time (s)", ha="center", va="center", fontsize=18, weight="bold", font="Times New Roman")
plt.legend(
["Force", "Fatigue"], loc="upper right", ncol=1, prop={"family": "Times New Roman", "size": 14, "weight": "bold"}
)
axs.text(0.75, 96.3, "Time (s)", ha="center", va="center", fontsize=18, weight="bold")
axs.legend(title="Cost function", fontsize="medium", loc="upper right", ncol=1)
plt.show()
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
This example showcases a moving time horizon simulation problem of cyclic muscle force tracking.
The FES model used here is Ding's 2007 pulse duration and frequency model with fatigue.
Only the pulse duration is optimized, frequency is fixed.
The nmpc cyclic problem is composed of 3 cycles and will move forward 1 cycle at each step.
Only the middle cycle is kept in the optimization problem, the nmpc cyclic problem stops once the last 6th cycle is reached.
"""

import numpy as np
import matplotlib.pyplot as plt

from bioptim import OdeSolver
from cocofest import OcpFesNmpcCyclic, DingModelPulseDurationFrequencyWithFatigue

# --- Build target force --- #
target_time = np.linspace(0, 1, 100)
target_force = abs(np.sin(target_time * np.pi)) * 200
force_tracking = [target_time, target_force]

# --- Build nmpc cyclic --- #
n_total_cycles = 8
minimum_pulse_duration = DingModelPulseDurationFrequencyWithFatigue().pd0
fes_model = DingModelPulseDurationFrequencyWithFatigue(sum_stim_truncation=10)
fes_model.alpha_a = -4.0 * 10e-1 # Increasing the fatigue rate to make the fatigue more visible
nmpc = OcpFesNmpcCyclic(
model=fes_model,
n_stim=30,
n_shooting=5,
final_time=1,
pulse_duration={
"min": minimum_pulse_duration,
"max": 0.0006,
"bimapping": False,
},
objective={"force_tracking": force_tracking},
n_total_cycles=n_total_cycles,
n_simultaneous_cycles=3,
n_cycle_to_advance=1,
cycle_to_keep="middle",
use_sx=True,
ode_solver=OdeSolver.COLLOCATION(),
)

nmpc.prepare_nmpc()
nmpc.solve()

# --- Show results --- #
time = [j for sub in nmpc.result["time"] for j in sub]
fatigue = [j for sub in nmpc.result["states"]["A"] for j in sub]
force = [j for sub in nmpc.result["states"]["F"] for j in sub]

ax1 = plt.subplot(221)
ax1.plot(time, fatigue, label="A", color="green")
ax1.set_title("Fatigue", weight="bold")
ax1.set_xlabel("Time (s)")
ax1.set_ylabel("Force scaling factor (-)")
plt.legend()

ax2 = plt.subplot(222)
ax2.plot(time, force, label="F", color="red", linewidth=4)
for i in range(n_total_cycles):
if i == 0:
ax2.plot(target_time, target_force, label="Target", color="purple")
else:
ax2.plot(target_time + i, target_force, color="purple")
ax2.set_title("Force", weight="bold")
ax2.set_xlabel("Time (s)")
ax2.set_ylabel("Force (N)")
plt.legend()

barWidth = 0.25 # set width of bar
cycles = nmpc.result["parameters"]["pulse_duration"] # set height of bar
bar = [] # Set position of bar on X axis
for i in range(n_total_cycles):
if i == 0:
br = [barWidth * (x + 1) for x in range(len(cycles[i]))]
else:
br = [bar[-1][-1] + barWidth * (x + 1) for x in range(len(cycles[i]))]
bar.append(br)

ax3 = plt.subplot(212)
for i in range(n_total_cycles):
ax3.bar(bar[i], cycles[i], width=barWidth, edgecolor="grey", label=f"cycle n°{i+1}")
ax3.set_xticks([np.mean(r) for r in bar], [str(i + 1) for i in range(n_total_cycles)])
ax3.set_xlabel("Cycles")
ax3.set_ylabel("Pulse duration (s)")
plt.legend()
ax3.set_title("Pulse duration", weight="bold")
plt.show()
Original file line number Diff line number Diff line change
Expand Up @@ -157,18 +157,18 @@
ticks=[1e-12, 1e-10, 1e-8, 1e-6, 1e-4, 1e-2, 1, max_error],
cmap=cmap,
)
cbar1.set_label(label="Force absolute error (N)", size=25, fontname="Times New Roman")
cbar1.set_label(label="Muscle force absolute error (N)", size=25, fontname="Times New Roman")

cbar1.ax.set_yticklabels(
[
"{:.0e}".format(float(1e-12)),
"{:.0e}".format(float(1e-10)),
"{:.0e}".format(float(1e-8)),
"{:.0e}".format(float(1e-6)),
"{:.0e}".format(float(1e-4)),
"{:.0e}".format(float(1e-2)),
"{:.0e}".format(float(1)),
"{:.1e}".format(float(round(max_error))),
"1e⁻¹²",
"1e⁻¹⁰",
"1e⁻⁸",
"1e⁻⁶",
"1e⁻⁴",
"1e⁻²",
"1e⁰",
"5.3e⁺¹",
],
size=25,
fontname="Times New Roman",
Expand All @@ -184,14 +184,14 @@
y_beneath_1e_8 = []
for j in range(len((all_mode_list_error_beneath_1e_8[i]))):
y_beneath_1e_8.append(parameter_list[i][all_mode_list_error_beneath_1e_8[i][j]][1])
axs.plot(x_beneath_1e_8, y_beneath_1e_8, color="darkred", label="Calcium absolute error < 1e-8", linewidth=3)
axs.plot(x_beneath_1e_8, y_beneath_1e_8, color="darkred", label=r"Calcium absolute error < 1e⁻⁸", linewidth=3)

axs.scatter(0, 0, color="white", label="OCP (s) | 100 Integrations (s)", marker="+", s=0, lw=0)
axs.scatter(0, 0, color="white", label="Initialization (s) | 100 Integrations (s)", marker="+", s=0, lw=0)
axs.scatter(
1,
1,
color="blue",
label=" " + str(round(a_ocp_time, 3)) + " " + str(round(a_integration_time, 3)),
label=" " + str(round(a_ocp_time, 3)) + " " + str(round(a_integration_time, 3)),
marker="^",
s=200,
lw=5,
Expand All @@ -200,7 +200,7 @@
100,
39,
color="black",
label=" " + str(round(b_ocp_time, 3)) + " " + str(round(b_integration_time, 3)),
label=" " + str(round(b_ocp_time, 3)) + " " + str(round(b_integration_time, 3)),
marker="+",
s=500,
lw=5,
Expand All @@ -209,15 +209,15 @@
100,
100,
color="green",
label=" " + str(round(c_ocp_time, 3)) + " " + str(round(c_integration_time, 3)),
label=" " + str(round(c_ocp_time, 3)) + " " + str(round(c_integration_time, 3)),
marker=",",
s=200,
lw=5,
)

axs.set_xlabel("Frequency (Hz)", fontsize=25, fontname="Times New Roman")
axs.xaxis.set_major_locator(MaxNLocator(integer=True))
axs.set_ylabel("Past stimulation kept for computation (n)", fontsize=25, fontname="Times New Roman")
axs.set_ylabel("Past stimulations kept for computation (n)", fontsize=25, fontname="Times New Roman")
axs.yaxis.set_major_locator(MaxNLocator(integer=True))

ticks = np.arange(1, 101, 1).tolist()
Expand Down
7 changes: 6 additions & 1 deletion cocofest/models/ding2007.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ class DingModelPulseDurationFrequency(DingModelFrequency):
Muscle & Nerve: Official Journal of the American Association of Electrodiagnostic Medicine, 36(2), 214-222.
"""

def __init__(self, model_name: str = "ding_2007", muscle_name: str = None, sum_stim_truncation: int = None):
def __init__(
self,
model_name: str = "ding_2007",
muscle_name: str = None,
sum_stim_truncation: int = None,
):
super(DingModelPulseDurationFrequency, self).__init__(
model_name=model_name, muscle_name=muscle_name, sum_stim_truncation=sum_stim_truncation
)
Expand Down
5 changes: 4 additions & 1 deletion cocofest/models/ding2007_with_fatigue.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ class DingModelPulseDurationFrequencyWithFatigue(DingModelPulseDurationFrequency
"""

def __init__(
self, model_name: str = "ding_2007_with_fatigue", muscle_name: str = None, sum_stim_truncation: int = None
self,
model_name: str = "ding_2007_with_fatigue",
muscle_name: str = None,
sum_stim_truncation: int = None,
):
super(DingModelPulseDurationFrequencyWithFatigue, self).__init__(
model_name=model_name, muscle_name=muscle_name, sum_stim_truncation=sum_stim_truncation
Expand Down
5 changes: 1 addition & 4 deletions cocofest/optimization/fes_ocp.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import numpy as np

from bioptim import (
BiMapping,
# BiMappingList, parameter mapping not yet implemented
BoundsList,
ConstraintList,
ControlType,
Expand All @@ -29,7 +27,6 @@
from ..models.ding2007 import DingModelPulseDurationFrequency
from ..models.ding2007_with_fatigue import DingModelPulseDurationFrequencyWithFatigue
from ..models.ding2003 import DingModelFrequency
from ..models.ding2003_with_fatigue import DingModelFrequencyWithFatigue
from ..models.hmed2018 import DingModelIntensityFrequency
from ..models.hmed2018_with_fatigue import DingModelIntensityFrequencyWithFatigue

Expand Down Expand Up @@ -158,7 +155,7 @@ def prepare_ocp(
force_fourier_coefficient = (
None if force_tracking is None else OcpFes._build_fourier_coefficient(force_tracking)
)
end_node_tracking = end_node_tracking

models = [model] * n_stim
n_shooting = [n_shooting] * n_stim

Expand Down
Loading

0 comments on commit f50d932

Please sign in to comment.