diff --git a/bioptim/dynamics/configure_new_variable.py b/bioptim/dynamics/configure_new_variable.py index 87344072c..fbfb10d64 100644 --- a/bioptim/dynamics/configure_new_variable.py +++ b/bioptim/dynamics/configure_new_variable.py @@ -17,20 +17,20 @@ class NewVariableConfiguration: def __init__( - self, - name: str, - name_elements: list, - ocp, - nlp, - as_states: bool, - as_controls: bool, - as_states_dot: bool = False, - as_stochastic: bool = False, - fatigue: FatigueList = None, - combine_name: str = None, - combine_state_control_plot: bool = False, - skip_plot: bool = False, - axes_idx: BiMapping = None, + self, + name: str, + name_elements: list, + ocp, + nlp, + as_states: bool, + as_controls: bool, + as_states_dot: bool = False, + as_stochastic: bool = False, + fatigue: FatigueList = None, + combine_name: str = None, + combine_state_control_plot: bool = False, + skip_plot: bool = False, + axes_idx: BiMapping = None, ): """ Add a new variable to the states/controls pool @@ -90,9 +90,7 @@ def __init__( self._check_combine_state_control_plot() - if _manage_fatigue_to_new_variable( - name, name_elements, ocp, nlp, as_states, as_controls, fatigue - ): + if _manage_fatigue_to_new_variable(name, name_elements, ocp, nlp, as_states, as_controls, fatigue): # If the element is fatigable, this function calls back configure_new_variable to fill everything. # Therefore, we can exit now return @@ -114,14 +112,13 @@ def __init__( self._declare_legend() self._declare_cx_and_plot() - def _check_combine_state_control_plot(self): - """ Check if combine_state_control_plot and combine_name are defined simultaneously """ + """Check if combine_state_control_plot and combine_name are defined simultaneously""" if self.combine_state_control_plot and self.combine_name is not None: raise ValueError("combine_name and combine_state_control_plot cannot be defined simultaneously") def _check_phase_mapping_of_variable(self): - """ Check if the use of states from another phases is compatible with assume_phase_dynamics """ + """Check if the use of states from another phases is compatible with assume_phase_dynamics""" if not self.ocp.assume_phase_dynamics and ( self.nlp.use_states_from_phase_idx != self.nlp.phase_idx or self.nlp.use_states_dot_from_phase_idx != self.nlp.phase_idx @@ -131,21 +128,21 @@ def _check_phase_mapping_of_variable(self): raise ValueError("map_state=True must be used alongside with assume_phase_dynamics=True") def _declare_phase_copy_booleans(self): - """ Use of states[0] and controls[0] is permitted since ocp.assume_phase_dynamics is True """ + """Use of states[0] and controls[0] is permitted since ocp.assume_phase_dynamics is True""" self.copy_states = ( - self.nlp.use_states_from_phase_idx is not None - and self.nlp.use_states_from_phase_idx < self.nlp.phase_idx - and self.name in self.ocp.nlp[self.nlp.use_states_from_phase_idx].states[0] + self.nlp.use_states_from_phase_idx is not None + and self.nlp.use_states_from_phase_idx < self.nlp.phase_idx + and self.name in self.ocp.nlp[self.nlp.use_states_from_phase_idx].states[0] ) self.copy_controls = ( - self.nlp.use_controls_from_phase_idx is not None - and self.nlp.use_controls_from_phase_idx < self.nlp.phase_idx - and self.name in self.ocp.nlp[self.nlp.use_controls_from_phase_idx].controls[0] + self.nlp.use_controls_from_phase_idx is not None + and self.nlp.use_controls_from_phase_idx < self.nlp.phase_idx + and self.name in self.ocp.nlp[self.nlp.use_controls_from_phase_idx].controls[0] ) self.copy_states_dot = ( - self.nlp.use_states_dot_from_phase_idx is not None - and self.nlp.use_states_dot_from_phase_idx < self.nlp.phase_idx - and self.name in self.ocp.nlp[self.nlp.use_states_dot_from_phase_idx].states_dot[0] + self.nlp.use_states_dot_from_phase_idx is not None + and self.nlp.use_states_dot_from_phase_idx < self.nlp.phase_idx + and self.name in self.ocp.nlp[self.nlp.use_states_dot_from_phase_idx].states_dot[0] ) def define_cx_scaled(self, n_col: int, n_shooting: int, initial_node) -> list: @@ -177,36 +174,60 @@ def define_cx_unscaled(self, _cx_scaled: list, scaling: np.ndarray) -> list: return _cx def _declare_auto_variable_mapping(self): - """ Declare the mapping of the new variable if not already declared """ + """Declare the mapping of the new variable if not already declared""" if self.name not in self.nlp.variable_mappings: - self.nlp.variable_mappings[self.name] = BiMapping(range(len(self.name_elements)), range(len(self.name_elements))) + self.nlp.variable_mappings[self.name] = BiMapping( + range(len(self.name_elements)), range(len(self.name_elements)) + ) def _declare_initial_guess(self): if self.as_states and self.name not in self.nlp.x_init: - self.nlp.x_init.add(self.name, initial_guess=np.zeros(len(self.nlp.variable_mappings[self.name].to_first.map_idx))) + self.nlp.x_init.add( + self.name, initial_guess=np.zeros(len(self.nlp.variable_mappings[self.name].to_first.map_idx)) + ) if self.as_controls and self.name not in self.nlp.u_init: - self.nlp.u_init.add(self.name, initial_guess=np.zeros(len(self.nlp.variable_mappings[self.name].to_first.map_idx))) + self.nlp.u_init.add( + self.name, initial_guess=np.zeros(len(self.nlp.variable_mappings[self.name].to_first.map_idx)) + ) def _declare_variable_scaling(self): if self.as_states and self.name not in self.nlp.x_scaling: - self.nlp.x_scaling.add(self.name, scaling=np.ones(len(self.nlp.variable_mappings[self.name].to_first.map_idx))) + self.nlp.x_scaling.add( + self.name, scaling=np.ones(len(self.nlp.variable_mappings[self.name].to_first.map_idx)) + ) if self.as_states_dot and self.name not in self.nlp.xdot_scaling: - self.nlp.xdot_scaling.add(self.name, scaling=np.ones(len(self.nlp.variable_mappings[self.name].to_first.map_idx))) + self.nlp.xdot_scaling.add( + self.name, scaling=np.ones(len(self.nlp.variable_mappings[self.name].to_first.map_idx)) + ) if self.as_controls and self.name not in self.nlp.u_scaling: - self.nlp.u_scaling.add(self.name, scaling=np.ones(len(self.nlp.variable_mappings[self.name].to_first.map_idx))) + self.nlp.u_scaling.add( + self.name, scaling=np.ones(len(self.nlp.variable_mappings[self.name].to_first.map_idx)) + ) def _use_copy(self): - """ Use of states[0] and controls[0] is permitted since ocp.assume_phase_dynamics is True """ - self.mx_states = [] if not self.copy_states else [self.ocp.nlp[self.nlp.use_states_from_phase_idx].states[0][self.name].mx] + """Use of states[0] and controls[0] is permitted since ocp.assume_phase_dynamics is True""" + self.mx_states = ( + [] if not self.copy_states else [self.ocp.nlp[self.nlp.use_states_from_phase_idx].states[0][self.name].mx] + ) self.mx_states_dot = ( - [] if not self.copy_states_dot else [self.ocp.nlp[self.nlp.use_states_dot_from_phase_idx].states_dot[0][self.name].mx] + [] + if not self.copy_states_dot + else [self.ocp.nlp[self.nlp.use_states_dot_from_phase_idx].states_dot[0][self.name].mx] + ) + self.mx_controls = ( + [] + if not self.copy_controls + else [self.ocp.nlp[self.nlp.use_controls_from_phase_idx].controls[0][self.name].mx] ) - self.mx_controls = [] if not self.copy_controls else [self.ocp.nlp[self.nlp.use_controls_from_phase_idx].controls[0][self.name].mx] self.mx_stochastic = [] # todo: if mapping on variables, what do we do with mapping on the nodes for i in self.nlp.variable_mappings[self.name].to_second.map_idx: - var_name = f"{'-' if np.sign(i) < 0 else ''}{self.name}_{self.name_elements[abs(i)]}_MX" if i is not None else "zero" + var_name = ( + f"{'-' if np.sign(i) < 0 else ''}{self.name}_{self.name_elements[abs(i)]}_MX" + if i is not None + else "zero" + ) if not self.copy_states: self.mx_states.append(MX.sym(var_name, 1, 1)) @@ -225,12 +246,12 @@ def _use_copy(self): self.mx_stochastic = vertcat(*self.mx_stochastic) def _declare_auto_axes_idx(self): - """ Declare the axes index if not already declared """ + """Declare the axes index if not already declared""" if not self.axes_idx: self.axes_idx = BiMapping(to_first=range(len(self.name_elements)), to_second=range(len(self.name_elements))) def _declare_legend(self): - """ Declare the legend if not already declared """ + """Declare the legend if not already declared""" self.legend = [] for idx, name_el in enumerate(self.name_elements): if idx is not None and idx in self.axes_idx.to_first.map_idx: @@ -245,7 +266,11 @@ def _declare_legend(self): def _declare_cx_and_plot(self): if self.as_states: for node_index in range((0 if self.ocp.assume_phase_dynamics else self.nlp.ns) + 1): - n_cx = self.nlp.ode_solver.polynomial_degree + 2 if isinstance(self.nlp.ode_solver, OdeSolver.COLLOCATION) else 3 + n_cx = ( + self.nlp.ode_solver.polynomial_degree + 2 + if isinstance(self.nlp.ode_solver, OdeSolver.COLLOCATION) + else 3 + ) cx_scaled = ( self.ocp.nlp[self.nlp.use_states_from_phase_idx].states[node_index][self.name].original_cx if self.copy_states @@ -256,7 +281,9 @@ def _declare_cx_and_plot(self): if self.copy_states else self.define_cx_unscaled(cx_scaled, self.nlp.x_scaling[self.name].scaling) ) - self.nlp.states.append(self.name, cx[0], cx_scaled[0], self.mx_states, self.nlp.variable_mappings[self.name], node_index) + self.nlp.states.append( + self.name, cx[0], cx_scaled[0], self.mx_states, self.nlp.variable_mappings[self.name], node_index + ) if not self.skip_plot: self.nlp.plot[f"{self.name}_states"] = CustomPlot( lambda t, x, u, p, s: x[self.nlp.states[self.name].index, :], @@ -284,7 +311,9 @@ def _declare_cx_and_plot(self): if self.copy_controls else self.define_cx_unscaled(cx_scaled, self.nlp.u_scaling[self.name].scaling) ) - self.nlp.controls.append(self.name, cx[0], cx_scaled[0], self.mx_controls, self.nlp.variable_mappings[self.name], node_index) + self.nlp.controls.append( + self.name, cx[0], cx_scaled[0], self.mx_controls, self.nlp.variable_mappings[self.name], node_index + ) plot_type = PlotType.PLOT if self.nlp.control_type == ControlType.LINEAR_CONTINUOUS else PlotType.STEP if not self.skip_plot: @@ -293,12 +322,18 @@ def _declare_cx_and_plot(self): plot_type=plot_type, axes_idx=self.axes_idx, legend=self.legend, - combine_to=f"{self.name}_states" if self.as_states and self.combine_state_control_plot else self.combine_name, + combine_to=f"{self.name}_states" + if self.as_states and self.combine_state_control_plot + else self.combine_name, ) if self.as_states_dot: for node_index in range((0 if self.ocp.assume_phase_dynamics else self.nlp.ns) + 1): - n_cx = self.nlp.ode_solver.polynomial_degree + 1 if isinstance(self.nlp.ode_solver, OdeSolver.COLLOCATION) else 3 + n_cx = ( + self.nlp.ode_solver.polynomial_degree + 1 + if isinstance(self.nlp.ode_solver, OdeSolver.COLLOCATION) + else 3 + ) if n_cx < 3: n_cx = 3 cx_scaled = ( @@ -311,16 +346,32 @@ def _declare_cx_and_plot(self): if self.copy_states_dot else self.define_cx_unscaled(cx_scaled, self.nlp.xdot_scaling[self.name].scaling) ) - self.nlp.states_dot.append(self.name, cx[0], cx_scaled[0], self.mx_states_dot, self.nlp.variable_mappings[self.name], node_index) + self.nlp.states_dot.append( + self.name, + cx[0], + cx_scaled[0], + self.mx_states_dot, + self.nlp.variable_mappings[self.name], + node_index, + ) if self.as_stochastic: for node_index in range((0 if self.ocp.assume_phase_dynamics else self.nlp.ns) + 1): - n_cx = self.nlp.ode_solver.polynomial_degree + 1 if isinstance(self.nlp.ode_solver, OdeSolver.COLLOCATION) else 3 + n_cx = ( + self.nlp.ode_solver.polynomial_degree + 1 + if isinstance(self.nlp.ode_solver, OdeSolver.COLLOCATION) + else 3 + ) if n_cx < 3: n_cx = 3 cx_scaled = self.define_cx_scaled(n_col=n_cx, n_shooting=1, initial_node=node_index) self.nlp.stochastic_variables.append( - self.name, cx_scaled[0], cx_scaled[0], self.mx_stochastic, self.nlp.variable_mappings[self.name], node_index + self.name, + cx_scaled[0], + cx_scaled[0], + self.mx_stochastic, + self.nlp.variable_mappings[self.name], + node_index, ) @@ -410,9 +461,7 @@ def _manage_fatigue_to_new_variable( color=color[i], ) elif i == 0: - NewVariableConfiguration( - f"{name}", name_elements, ocp, nlp, as_states, as_controls, skip_plot=True - ) + NewVariableConfiguration(f"{name}", name_elements, ocp, nlp, as_states, as_controls, skip_plot=True) nlp.plot[f"{name}_controls"] = CustomPlot( lambda t, x, u, p, s, key: u[nlp.controls[key].index, :], plot_type=PlotType.STEP,