diff --git a/mlprof/plotting/plotter.py b/mlprof/plotting/plotter.py index 8e76a0d..15fb04a 100644 --- a/mlprof/plotting/plotter.py +++ b/mlprof/plotting/plotter.py @@ -119,6 +119,7 @@ def plot_batch_size_several_measurements( input_paths, output_path, measurements, + color_list, plot_params, ): """ @@ -160,14 +161,15 @@ def plot_batch_size_several_measurements( # set the color cycle to the custom color cycle ax._get_lines.set_prop_cycle(cycler("color", colors[plot_params.get("custom_colors")])) - for data in plot_data: + for i, data in enumerate(plot_data): + color_used = color_list[i] if color_list[i] else ax._get_lines.get_next_color() entry = fill_plot( x=batch_sizes, y=data["y"], y_down=data["y_down"], y_up=data["y_up"], error_style=plot_params["error_style"], - color=ax._get_lines.get_next_color(), + color=color_used, ) legend_entries.append(entry) diff --git a/mlprof/tasks/runtime.py b/mlprof/tasks/runtime.py index 6bf0395..a1463df 100644 --- a/mlprof/tasks/runtime.py +++ b/mlprof/tasks/runtime.py @@ -150,6 +150,7 @@ def run(self): [self.input().path], output.path, [self.model.full_model_label], + [self.model.color], self.custom_plot_params, ) print("plot saved") @@ -280,5 +281,6 @@ def run(self): input_paths=input_paths, output_path=output.path, measurements=self.params_product_params_to_write, + color_list=[model.color for model in self.models], plot_params=self.custom_plot_params, ) diff --git a/mlprof/util.py b/mlprof/util.py index 5d10f42..7689b52 100644 --- a/mlprof/util.py +++ b/mlprof/util.py @@ -26,6 +26,7 @@ def __init__(self, model_file: str, name: str, label: str, **kwargs) -> None: self.model_file = expand_path(model_file, abs=True) self.name = name self.label = label + self._color = None # cached data self._all_data = None @@ -67,3 +68,9 @@ def full_model_label(self): # fallback to the full model name return self.full_name + + @property + def color(self): + if self._color is None: + self._color = self.data.get("color") + return self._color