diff --git a/bioptim/gui/plot.py b/bioptim/gui/plot.py index f8736e899..10bcfc223 100644 --- a/bioptim/gui/plot.py +++ b/bioptim/gui/plot.py @@ -4,7 +4,7 @@ import numpy as np from casadi import DM from matplotlib import pyplot as plt, lines -from matplotlib.ticker import StrMethodFormatter +from matplotlib.ticker import FuncFormatter from .serializable_class import OcpSerializable from ..dynamics.ode_solver import OdeSolver @@ -180,9 +180,9 @@ class PlotOcp: ------- _update_time_vector(self) Setup the time and time integrated vector, which is the x-axes of the graphs - __create_plots(self) + _create_plots(self) Setup the plots - __add_new_axis(self, variable: str, nb: int, n_rows: int, n_cols: int) + _add_new_axis(self, variable: str, nb: int, n_rows: int, n_cols: int) Add a new axis to the axes pool _organize_windows(self, n_windows: int) Automatically organize the figure across the screen. @@ -192,11 +192,11 @@ class PlotOcp: Force the show of the graphs. This is a blocking function update_data(self, v: dict) Update ydata from the variable a solution structure - __update_xdata(self) + _update_xdata(self) Update of the time axes in plots - __update_axes(self) + _update_axes(self) Update the plotted data from ydata - __compute_ylim(min_val: np.ndarray | DM, max_val: np.ndarray | DM, factor: float) -> tuple: + _compute_ylim(min_val: np.ndarray | DM, max_val: np.ndarray | DM, factor: float) -> tuple: Dynamically find the ylim _generate_windows_size(nb: int) -> tuple[int, int] Defines the number of column and rows of subplots from the number of variables to plot. @@ -492,7 +492,7 @@ def legend_without_duplicate_labels(ax): if y_max.__array__()[0] > y_max_all[y_range_var_idx][mapping_to_first_index.index(ctr)]: y_max_all[y_range_var_idx][mapping_to_first_index.index(ctr)] = y_max - y_range, _ = self.__compute_ylim( + y_range = self._compute_ylim( y_min_all[y_range_var_idx][mapping_to_first_index.index(ctr)], y_max_all[y_range_var_idx][mapping_to_first_index.index(ctr)], 1.25, @@ -651,7 +651,8 @@ def _add_new_axis(self, variable: str, nb: int, n_rows: int, n_cols: int): self.all_figures[-1].tight_layout() for ax in axes: - ax.yaxis.set_major_formatter(StrMethodFormatter("{x:,.1f}")) # 1 decimal places + ax.yaxis.set_major_formatter(FuncFormatter(lambda value, tick_value: f"{value:.2f}")) + return axes def _organize_windows(self, n_windows: int): @@ -996,15 +997,8 @@ def _update_ydata(self, ydata): if isinstance(p, lines.Line2D): y_min = min(y_min, np.nanmin(p.get_ydata())) y_max = max(y_max, np.nanmax(p.get_ydata())) - y_range, data_range = self.__compute_ylim(y_min, y_max, 1.25) - ax.set_ylim(y_range) - ax.set_yticks( - np.arange( - y_range[0], - y_range[1], - step=data_range / 4, - ) - ) + ax.set_ylim(self._compute_ylim(y_min, y_max, 1.25)) + for p in self.plots_vertical_lines: p.set_ydata((0, 1)) @@ -1013,7 +1007,7 @@ def _update_ydata(self, ydata): # TODO: set_tight_layout function will be deprecated. Use set_layout_engine instead. @staticmethod - def __compute_ylim(min_val: np.ndarray | DM, max_val: np.ndarray | DM, factor: float) -> tuple: + def _compute_ylim(min_val: np.ndarray | DM, max_val: np.ndarray | DM, factor: float) -> tuple: """ Dynamically find the ylim @@ -1040,8 +1034,7 @@ def __compute_ylim(min_val: np.ndarray | DM, max_val: np.ndarray | DM, factor: f if np.abs(data_range) < 0.8: data_range = 0.8 y_range = (factor * data_range) / 2 - y_range = data_mean - y_range, data_mean + y_range - return y_range, data_range + return data_mean - y_range, data_mean + y_range @staticmethod def _generate_windows_size(nb: int) -> tuple: