Skip to content

Commit

Permalink
solve numpy error for numpy above 2.0, set fix version values in plot…
Browse files Browse the repository at this point in the history
…ting sandbox, add color cycle option
  • Loading branch information
nprouvost committed Aug 2, 2024
1 parent bd02706 commit edae064
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 9 deletions.
32 changes: 28 additions & 4 deletions mlprof/plotting/plotter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,30 @@
# coding: utf-8

colors = {
"mpl_standard": [
"mpl": [
"#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf",
],
"custom_edgecolor": ["#CC4F1B", "#1B2ACC", "#3F7F4C"],
"custom_facecolor": ["#FF9848", "#089FFF", "#7EFF99"],
# Atlas and cms standards correspond to results in :
# Color wheel from https://arxiv.org/pdf/2107.02270 Table 1, 10 color palette
# hexacodes in https://github.com/mpetroff/accessible-color-cycles/blob/0a17e754d9f83161baffd803dcea8bee7d95a549/readme.md#final-results # noqa
# as implemented in mplhep
"cms_6": [
"#5790fc", "#f89c20", "#e42536", "#964a8b", "#9c9ca1", "#7a21dd",
],
"atlas_10": [
"#3f90da",
"#ffa90e",
"#bd1f01",
"#94a4a2",
"#832db6",
"#a96b59",
"#e76300",
"#b9ac70",
"#717581",
"#92dadd",
],
# "custom_edgecolor": ["#CC4F1B", "#1B2ACC", "#3F7F4C"],
# "custom_facecolor": ["#FF9848", "#089FFF", "#7EFF99"],
}


Expand Down Expand Up @@ -86,7 +105,7 @@ def fill_plot(x, y, y_down, y_up, error_style, color):
if error_style == "band":
p1 = plt.plot(x, y, "-", color=color)
plt.fill_between(x, y - y_down, y + y_up, alpha=0.5, facecolor=color)
p2 = plt.fill(np.NaN, np.NaN, alpha=0.5, color=color)
p2 = plt.fill(np.nan, np.nan, alpha=0.5, color=color)
legend = (p1[0], p2[0])
else: # bars
p = plt.errorbar(x, y, yerr=(y_down, y_up), capsize=12, marker=".", linestyle="")
Expand Down Expand Up @@ -114,6 +133,7 @@ def plot_batch_size_several_measurements(
"""
import matplotlib.pyplot as plt
import mplhep # type: ignore[import-untyped]
from cycler import cycler

if isinstance(measurements[0], str):
measurements_labels_strs = list(measurements)
Expand All @@ -136,6 +156,10 @@ def plot_batch_size_several_measurements(

# create plot with curves using a single color for each value-error pair
legend_entries = []
if plot_params.get("custom_colors"):
# set the color cycle to the custom color cycle
ax._get_lines.set_prop_cycle(cycler("color", colors[plot_params.get("custom_colors")]))

for data in plot_data:
entry = fill_plot(
x=batch_sizes,
Expand Down
9 changes: 9 additions & 0 deletions mlprof/tasks/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ class MultiModelParameters(BaseTask):
description="when set, use these labels in plots; when empty, the `label` fields in the models "
"yaml data are used when existing, else the `name` fields in the models yaml data are used when "
"existing and model-names otherwise; default: empty",
brace_expand=True,
)

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -312,6 +313,13 @@ class CustomPlotParameters(BaseTask):
significant=False,
description="stick a label over the top right corner of the plot",
)
custom_colors = luigi.ChoiceParameter(
choices=["mpl", "cms_6", "atlas_10"],
default="cms_6",
significant=False,
description="default color cycle to use; choices: 'mpl', 'cms_6', 'atlas_10'"
"; default: 'cms_6'",
)

@property
def custom_plot_params(self):
Expand All @@ -323,4 +331,5 @@ def custom_plot_params(self):
"bs_normalized": self.bs_normalized,
"error_style": self.error_style,
"top_right_label": None if self.top_right_label == law.NO_STR else self.top_right_label,
"custom_colors": self.custom_colors,
}
10 changes: 5 additions & 5 deletions sandboxes/plotting.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# version 1
# version 2

numpy
pandas
matplotlib
mplhep
numpy~=2.0.1
pandas~=2.2.2
matplotlib~=3.9.1
mplhep~=0.3.50

0 comments on commit edae064

Please sign in to comment.