Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
Ipuch committed Jul 21, 2023
1 parent 29f6038 commit 7784268
Showing 1 changed file with 104 additions and 55 deletions.
159 changes: 104 additions & 55 deletions bioptim/dynamics/configure_new_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@

class NewVariableConfiguration:
def __init__(
self,
name: str,
name_elements: list,
ocp,
nlp,
as_states: bool,
as_controls: bool,
as_states_dot: bool = False,
as_stochastic: bool = False,
fatigue: FatigueList = None,
combine_name: str = None,
combine_state_control_plot: bool = False,
skip_plot: bool = False,
axes_idx: BiMapping = None,
self,
name: str,
name_elements: list,
ocp,
nlp,
as_states: bool,
as_controls: bool,
as_states_dot: bool = False,
as_stochastic: bool = False,
fatigue: FatigueList = None,
combine_name: str = None,
combine_state_control_plot: bool = False,
skip_plot: bool = False,
axes_idx: BiMapping = None,
):
"""
Add a new variable to the states/controls pool
Expand Down Expand Up @@ -90,9 +90,7 @@ def __init__(

self._check_combine_state_control_plot()

if _manage_fatigue_to_new_variable(
name, name_elements, ocp, nlp, as_states, as_controls, fatigue
):
if _manage_fatigue_to_new_variable(name, name_elements, ocp, nlp, as_states, as_controls, fatigue):
# If the element is fatigable, this function calls back configure_new_variable to fill everything.
# Therefore, we can exit now
return
Expand All @@ -114,14 +112,13 @@ def __init__(
self._declare_legend()
self._declare_cx_and_plot()


def _check_combine_state_control_plot(self):
""" Check if combine_state_control_plot and combine_name are defined simultaneously """
"""Check if combine_state_control_plot and combine_name are defined simultaneously"""
if self.combine_state_control_plot and self.combine_name is not None:
raise ValueError("combine_name and combine_state_control_plot cannot be defined simultaneously")

Check warning on line 118 in bioptim/dynamics/configure_new_variable.py

View check run for this annotation

Codecov / codecov/patch

bioptim/dynamics/configure_new_variable.py#L118

Added line #L118 was not covered by tests

def _check_phase_mapping_of_variable(self):
""" Check if the use of states from another phases is compatible with assume_phase_dynamics """
"""Check if the use of states from another phases is compatible with assume_phase_dynamics"""
if not self.ocp.assume_phase_dynamics and (
self.nlp.use_states_from_phase_idx != self.nlp.phase_idx
or self.nlp.use_states_dot_from_phase_idx != self.nlp.phase_idx
Expand All @@ -131,21 +128,21 @@ def _check_phase_mapping_of_variable(self):
raise ValueError("map_state=True must be used alongside with assume_phase_dynamics=True")

Check warning on line 128 in bioptim/dynamics/configure_new_variable.py

View check run for this annotation

Codecov / codecov/patch

bioptim/dynamics/configure_new_variable.py#L128

Added line #L128 was not covered by tests

def _declare_phase_copy_booleans(self):
""" Use of states[0] and controls[0] is permitted since ocp.assume_phase_dynamics is True """
"""Use of states[0] and controls[0] is permitted since ocp.assume_phase_dynamics is True"""
self.copy_states = (
self.nlp.use_states_from_phase_idx is not None
and self.nlp.use_states_from_phase_idx < self.nlp.phase_idx
and self.name in self.ocp.nlp[self.nlp.use_states_from_phase_idx].states[0]
self.nlp.use_states_from_phase_idx is not None
and self.nlp.use_states_from_phase_idx < self.nlp.phase_idx
and self.name in self.ocp.nlp[self.nlp.use_states_from_phase_idx].states[0]
)
self.copy_controls = (
self.nlp.use_controls_from_phase_idx is not None
and self.nlp.use_controls_from_phase_idx < self.nlp.phase_idx
and self.name in self.ocp.nlp[self.nlp.use_controls_from_phase_idx].controls[0]
self.nlp.use_controls_from_phase_idx is not None
and self.nlp.use_controls_from_phase_idx < self.nlp.phase_idx
and self.name in self.ocp.nlp[self.nlp.use_controls_from_phase_idx].controls[0]
)
self.copy_states_dot = (
self.nlp.use_states_dot_from_phase_idx is not None
and self.nlp.use_states_dot_from_phase_idx < self.nlp.phase_idx
and self.name in self.ocp.nlp[self.nlp.use_states_dot_from_phase_idx].states_dot[0]
self.nlp.use_states_dot_from_phase_idx is not None
and self.nlp.use_states_dot_from_phase_idx < self.nlp.phase_idx
and self.name in self.ocp.nlp[self.nlp.use_states_dot_from_phase_idx].states_dot[0]
)

def define_cx_scaled(self, n_col: int, n_shooting: int, initial_node) -> list:
Expand Down Expand Up @@ -177,36 +174,60 @@ def define_cx_unscaled(self, _cx_scaled: list, scaling: np.ndarray) -> list:
return _cx

def _declare_auto_variable_mapping(self):
""" Declare the mapping of the new variable if not already declared """
"""Declare the mapping of the new variable if not already declared"""
if self.name not in self.nlp.variable_mappings:
self.nlp.variable_mappings[self.name] = BiMapping(range(len(self.name_elements)), range(len(self.name_elements)))
self.nlp.variable_mappings[self.name] = BiMapping(
range(len(self.name_elements)), range(len(self.name_elements))
)

def _declare_initial_guess(self):
if self.as_states and self.name not in self.nlp.x_init:
self.nlp.x_init.add(self.name, initial_guess=np.zeros(len(self.nlp.variable_mappings[self.name].to_first.map_idx)))
self.nlp.x_init.add(
self.name, initial_guess=np.zeros(len(self.nlp.variable_mappings[self.name].to_first.map_idx))
)
if self.as_controls and self.name not in self.nlp.u_init:
self.nlp.u_init.add(self.name, initial_guess=np.zeros(len(self.nlp.variable_mappings[self.name].to_first.map_idx)))
self.nlp.u_init.add(
self.name, initial_guess=np.zeros(len(self.nlp.variable_mappings[self.name].to_first.map_idx))
)

def _declare_variable_scaling(self):
if self.as_states and self.name not in self.nlp.x_scaling:
self.nlp.x_scaling.add(self.name, scaling=np.ones(len(self.nlp.variable_mappings[self.name].to_first.map_idx)))
self.nlp.x_scaling.add(
self.name, scaling=np.ones(len(self.nlp.variable_mappings[self.name].to_first.map_idx))
)
if self.as_states_dot and self.name not in self.nlp.xdot_scaling:
self.nlp.xdot_scaling.add(self.name, scaling=np.ones(len(self.nlp.variable_mappings[self.name].to_first.map_idx)))
self.nlp.xdot_scaling.add(
self.name, scaling=np.ones(len(self.nlp.variable_mappings[self.name].to_first.map_idx))
)
if self.as_controls and self.name not in self.nlp.u_scaling:
self.nlp.u_scaling.add(self.name, scaling=np.ones(len(self.nlp.variable_mappings[self.name].to_first.map_idx)))
self.nlp.u_scaling.add(
self.name, scaling=np.ones(len(self.nlp.variable_mappings[self.name].to_first.map_idx))
)

def _use_copy(self):
""" Use of states[0] and controls[0] is permitted since ocp.assume_phase_dynamics is True """
self.mx_states = [] if not self.copy_states else [self.ocp.nlp[self.nlp.use_states_from_phase_idx].states[0][self.name].mx]
"""Use of states[0] and controls[0] is permitted since ocp.assume_phase_dynamics is True"""
self.mx_states = (
[] if not self.copy_states else [self.ocp.nlp[self.nlp.use_states_from_phase_idx].states[0][self.name].mx]
)
self.mx_states_dot = (
[] if not self.copy_states_dot else [self.ocp.nlp[self.nlp.use_states_dot_from_phase_idx].states_dot[0][self.name].mx]
[]
if not self.copy_states_dot
else [self.ocp.nlp[self.nlp.use_states_dot_from_phase_idx].states_dot[0][self.name].mx]
)
self.mx_controls = (
[]
if not self.copy_controls
else [self.ocp.nlp[self.nlp.use_controls_from_phase_idx].controls[0][self.name].mx]
)
self.mx_controls = [] if not self.copy_controls else [self.ocp.nlp[self.nlp.use_controls_from_phase_idx].controls[0][self.name].mx]
self.mx_stochastic = []

# todo: if mapping on variables, what do we do with mapping on the nodes
for i in self.nlp.variable_mappings[self.name].to_second.map_idx:
var_name = f"{'-' if np.sign(i) < 0 else ''}{self.name}_{self.name_elements[abs(i)]}_MX" if i is not None else "zero"
var_name = (
f"{'-' if np.sign(i) < 0 else ''}{self.name}_{self.name_elements[abs(i)]}_MX"
if i is not None
else "zero"
)

if not self.copy_states:
self.mx_states.append(MX.sym(var_name, 1, 1))
Expand All @@ -225,12 +246,12 @@ def _use_copy(self):
self.mx_stochastic = vertcat(*self.mx_stochastic)

def _declare_auto_axes_idx(self):
""" Declare the axes index if not already declared """
"""Declare the axes index if not already declared"""
if not self.axes_idx:
self.axes_idx = BiMapping(to_first=range(len(self.name_elements)), to_second=range(len(self.name_elements)))

def _declare_legend(self):
""" Declare the legend if not already declared """
"""Declare the legend if not already declared"""
self.legend = []
for idx, name_el in enumerate(self.name_elements):
if idx is not None and idx in self.axes_idx.to_first.map_idx:
Expand All @@ -245,7 +266,11 @@ def _declare_legend(self):
def _declare_cx_and_plot(self):
if self.as_states:
for node_index in range((0 if self.ocp.assume_phase_dynamics else self.nlp.ns) + 1):
n_cx = self.nlp.ode_solver.polynomial_degree + 2 if isinstance(self.nlp.ode_solver, OdeSolver.COLLOCATION) else 3
n_cx = (
self.nlp.ode_solver.polynomial_degree + 2
if isinstance(self.nlp.ode_solver, OdeSolver.COLLOCATION)
else 3
)
cx_scaled = (
self.ocp.nlp[self.nlp.use_states_from_phase_idx].states[node_index][self.name].original_cx
if self.copy_states
Expand All @@ -256,7 +281,9 @@ def _declare_cx_and_plot(self):
if self.copy_states
else self.define_cx_unscaled(cx_scaled, self.nlp.x_scaling[self.name].scaling)
)
self.nlp.states.append(self.name, cx[0], cx_scaled[0], self.mx_states, self.nlp.variable_mappings[self.name], node_index)
self.nlp.states.append(
self.name, cx[0], cx_scaled[0], self.mx_states, self.nlp.variable_mappings[self.name], node_index
)
if not self.skip_plot:
self.nlp.plot[f"{self.name}_states"] = CustomPlot(
lambda t, x, u, p, s: x[self.nlp.states[self.name].index, :],
Expand Down Expand Up @@ -284,7 +311,9 @@ def _declare_cx_and_plot(self):
if self.copy_controls
else self.define_cx_unscaled(cx_scaled, self.nlp.u_scaling[self.name].scaling)
)
self.nlp.controls.append(self.name, cx[0], cx_scaled[0], self.mx_controls, self.nlp.variable_mappings[self.name], node_index)
self.nlp.controls.append(
self.name, cx[0], cx_scaled[0], self.mx_controls, self.nlp.variable_mappings[self.name], node_index
)

plot_type = PlotType.PLOT if self.nlp.control_type == ControlType.LINEAR_CONTINUOUS else PlotType.STEP
if not self.skip_plot:
Expand All @@ -293,12 +322,18 @@ def _declare_cx_and_plot(self):
plot_type=plot_type,
axes_idx=self.axes_idx,
legend=self.legend,
combine_to=f"{self.name}_states" if self.as_states and self.combine_state_control_plot else self.combine_name,
combine_to=f"{self.name}_states"
if self.as_states and self.combine_state_control_plot
else self.combine_name,
)

if self.as_states_dot:
for node_index in range((0 if self.ocp.assume_phase_dynamics else self.nlp.ns) + 1):
n_cx = self.nlp.ode_solver.polynomial_degree + 1 if isinstance(self.nlp.ode_solver, OdeSolver.COLLOCATION) else 3
n_cx = (
self.nlp.ode_solver.polynomial_degree + 1
if isinstance(self.nlp.ode_solver, OdeSolver.COLLOCATION)
else 3
)
if n_cx < 3:
n_cx = 3
cx_scaled = (
Expand All @@ -311,16 +346,32 @@ def _declare_cx_and_plot(self):
if self.copy_states_dot
else self.define_cx_unscaled(cx_scaled, self.nlp.xdot_scaling[self.name].scaling)
)
self.nlp.states_dot.append(self.name, cx[0], cx_scaled[0], self.mx_states_dot, self.nlp.variable_mappings[self.name], node_index)
self.nlp.states_dot.append(
self.name,
cx[0],
cx_scaled[0],
self.mx_states_dot,
self.nlp.variable_mappings[self.name],
node_index,
)

if self.as_stochastic:
for node_index in range((0 if self.ocp.assume_phase_dynamics else self.nlp.ns) + 1):
n_cx = self.nlp.ode_solver.polynomial_degree + 1 if isinstance(self.nlp.ode_solver, OdeSolver.COLLOCATION) else 3
n_cx = (
self.nlp.ode_solver.polynomial_degree + 1
if isinstance(self.nlp.ode_solver, OdeSolver.COLLOCATION)
else 3
)
if n_cx < 3:
n_cx = 3

Check warning on line 366 in bioptim/dynamics/configure_new_variable.py

View check run for this annotation

Codecov / codecov/patch

bioptim/dynamics/configure_new_variable.py#L366

Added line #L366 was not covered by tests
cx_scaled = self.define_cx_scaled(n_col=n_cx, n_shooting=1, initial_node=node_index)
self.nlp.stochastic_variables.append(
self.name, cx_scaled[0], cx_scaled[0], self.mx_stochastic, self.nlp.variable_mappings[self.name], node_index
self.name,
cx_scaled[0],
cx_scaled[0],
self.mx_stochastic,
self.nlp.variable_mappings[self.name],
node_index,
)


Expand Down Expand Up @@ -410,9 +461,7 @@ def _manage_fatigue_to_new_variable(
color=color[i],
)
elif i == 0:
NewVariableConfiguration(
f"{name}", name_elements, ocp, nlp, as_states, as_controls, skip_plot=True
)
NewVariableConfiguration(f"{name}", name_elements, ocp, nlp, as_states, as_controls, skip_plot=True)
nlp.plot[f"{name}_controls"] = CustomPlot(
lambda t, x, u, p, s, key: u[nlp.controls[key].index, :],
plot_type=PlotType.STEP,
Expand Down

0 comments on commit 7784268

Please sign in to comment.