Skip to content

Commit

Permalink
add option to modify model color in the yaml files
Browse files Browse the repository at this point in the history
  • Loading branch information
nprouvost committed Aug 5, 2024
1 parent edae064 commit b696661
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 2 deletions.
6 changes: 4 additions & 2 deletions mlprof/plotting/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def plot_batch_size_several_measurements(
input_paths,
output_path,
measurements,
color_list,
plot_params,
):
"""
Expand Down Expand Up @@ -160,14 +161,15 @@ def plot_batch_size_several_measurements(
# set the color cycle to the custom color cycle
ax._get_lines.set_prop_cycle(cycler("color", colors[plot_params.get("custom_colors")]))

for data in plot_data:
for i, data in enumerate(plot_data):
color_used = color_list[i] if color_list[i] else ax._get_lines.get_next_color()
entry = fill_plot(
x=batch_sizes,
y=data["y"],
y_down=data["y_down"],
y_up=data["y_up"],
error_style=plot_params["error_style"],
color=ax._get_lines.get_next_color(),
color=color_used,
)
legend_entries.append(entry)

Expand Down
2 changes: 2 additions & 0 deletions mlprof/tasks/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def run(self):
[self.input().path],
output.path,
[self.model.full_model_label],
[self.model.color],
self.custom_plot_params,
)
print("plot saved")
Expand Down Expand Up @@ -280,5 +281,6 @@ def run(self):
input_paths=input_paths,
output_path=output.path,
measurements=self.params_product_params_to_write,
color_list=[model.color for model in self.models],
plot_params=self.custom_plot_params,
)
7 changes: 7 additions & 0 deletions mlprof/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self, model_file: str, name: str, label: str, **kwargs) -> None:
self.model_file = expand_path(model_file, abs=True)
self.name = name
self.label = label
self._color = None

# cached data
self._all_data = None
Expand Down Expand Up @@ -67,3 +68,9 @@ def full_model_label(self):

# fallback to the full model name
return self.full_name

@property
def color(self):
if self._color is None:
self._color = self.data.get("color")
return self._color

0 comments on commit b696661

Please sign in to comment.