From edae064a1c483f8471c9257c27fda843998eab42 Mon Sep 17 00:00:00 2001 From: Nathan Prouvost Date: Fri, 2 Aug 2024 18:18:31 +0200 Subject: [PATCH] solve numpy error for numpy above 2.0, set fix version values in plotting sandbox, add color cycle option --- mlprof/plotting/plotter.py | 32 ++++++++++++++++++++++++++++---- mlprof/tasks/parameters.py | 9 +++++++++ sandboxes/plotting.txt | 10 +++++----- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/mlprof/plotting/plotter.py b/mlprof/plotting/plotter.py index 0e59858..8e76a0d 100644 --- a/mlprof/plotting/plotter.py +++ b/mlprof/plotting/plotter.py @@ -1,11 +1,30 @@ # coding: utf-8 colors = { - "mpl_standard": [ + "mpl": [ "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf", ], - "custom_edgecolor": ["#CC4F1B", "#1B2ACC", "#3F7F4C"], - "custom_facecolor": ["#FF9848", "#089FFF", "#7EFF99"], + # Atlas and cms standards correspond to results in : + # Color wheel from https://arxiv.org/pdf/2107.02270 Table 1, 10 color palette + # hexacodes in https://github.com/mpetroff/accessible-color-cycles/blob/0a17e754d9f83161baffd803dcea8bee7d95a549/readme.md#final-results # noqa + # as implemented in mplhep + "cms_6": [ + "#5790fc", "#f89c20", "#e42536", "#964a8b", "#9c9ca1", "#7a21dd", + ], + "atlas_10": [ + "#3f90da", + "#ffa90e", + "#bd1f01", + "#94a4a2", + "#832db6", + "#a96b59", + "#e76300", + "#b9ac70", + "#717581", + "#92dadd", + ], + # "custom_edgecolor": ["#CC4F1B", "#1B2ACC", "#3F7F4C"], + # "custom_facecolor": ["#FF9848", "#089FFF", "#7EFF99"], } @@ -86,7 +105,7 @@ def fill_plot(x, y, y_down, y_up, error_style, color): if error_style == "band": p1 = plt.plot(x, y, "-", color=color) plt.fill_between(x, y - y_down, y + y_up, alpha=0.5, facecolor=color) - p2 = plt.fill(np.NaN, np.NaN, alpha=0.5, color=color) + p2 = plt.fill(np.nan, np.nan, alpha=0.5, color=color) legend = (p1[0], p2[0]) else: # bars p = plt.errorbar(x, y, yerr=(y_down, y_up), capsize=12, marker=".", linestyle="") @@ -114,6 +133,7 @@ def plot_batch_size_several_measurements( """ import matplotlib.pyplot as plt import mplhep # type: ignore[import-untyped] + from cycler import cycler if isinstance(measurements[0], str): measurements_labels_strs = list(measurements) @@ -136,6 +156,10 @@ def plot_batch_size_several_measurements( # create plot with curves using a single color for each value-error pair legend_entries = [] + if plot_params.get("custom_colors"): + # set the color cycle to the custom color cycle + ax._get_lines.set_prop_cycle(cycler("color", colors[plot_params.get("custom_colors")])) + for data in plot_data: entry = fill_plot( x=batch_sizes, diff --git a/mlprof/tasks/parameters.py b/mlprof/tasks/parameters.py index 01a0a3c..358f6f8 100644 --- a/mlprof/tasks/parameters.py +++ b/mlprof/tasks/parameters.py @@ -211,6 +211,7 @@ class MultiModelParameters(BaseTask): description="when set, use these labels in plots; when empty, the `label` fields in the models " "yaml data are used when existing, else the `name` fields in the models yaml data are used when " "existing and model-names otherwise; default: empty", + brace_expand=True, ) def __init__(self, *args, **kwargs): @@ -312,6 +313,13 @@ class CustomPlotParameters(BaseTask): significant=False, description="stick a label over the top right corner of the plot", ) + custom_colors = luigi.ChoiceParameter( + choices=["mpl", "cms_6", "atlas_10"], + default="cms_6", + significant=False, + description="default color cycle to use; choices: 'mpl', 'cms_6', 'atlas_10'" + "; default: 'cms_6'", + ) @property def custom_plot_params(self): @@ -323,4 +331,5 @@ def custom_plot_params(self): "bs_normalized": self.bs_normalized, "error_style": self.error_style, "top_right_label": None if self.top_right_label == law.NO_STR else self.top_right_label, + "custom_colors": self.custom_colors, } diff --git a/sandboxes/plotting.txt b/sandboxes/plotting.txt index 7a2a915..dff287b 100644 --- a/sandboxes/plotting.txt +++ b/sandboxes/plotting.txt @@ -1,6 +1,6 @@ -# version 1 +# version 2 -numpy -pandas -matplotlib -mplhep +numpy~=2.0.1 +pandas~=2.2.2 +matplotlib~=3.9.1 +mplhep~=0.3.50