diff --git a/README.md b/README.md index d00e444..b3f04fd 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,7 @@ If tensorflow or onnx is to be used as inference engine, the yaml file should be model: name: optional_default_name_of_the_network_for_the_storage_path label: optional default label of the network for the plots + color: optional color of the network for the plots version: optional_version_number # (e.g. "1.0.0") inference_engine: name_of_inference_engine # (either "tf" or "onnx") file: path_to_your_pb_or_onnx_model_file @@ -153,6 +154,7 @@ Hence, the yaml file should be in the following format: model: name: optional_default_name_of_the_network_for_the_storage_path label: optional default label of the network for the plots + color: optional color of the network for the plots version: version_number # (e.g. "1.0.0") inference_engine: tfaot saved_model: path_to_your_saved_model_directory @@ -264,7 +266,7 @@ This task merges the .csv output files with the required multiple batch sizes fr - The .csv files from the several occurences of `MeasureRuntime` (one for each batch size). ## Parameters: -- batch-sizes: int. The comma-separated list of batch sizes to be tested; default: `1,2,4`. +- batch-sizes: int. The comma-separated list of batch sizes to be tested. default: `1,2,4`. - model-file: str. The absolute path of the yaml file containing the informations of the model to be tested. default: `$MLP_BASE/examples/dnn/model_tf_l10u128.yaml`. @@ -308,21 +310,23 @@ The number of inferences behind one plotted data point is given by `n-events * n - The .csv file from the `MergeRuntimes` task. ## Parameters: -- y-log: bool. Plot the y-axis values logarithmically; default: `False`. +- y-log: bool. Plot the y-axis values logarithmically. default: `False`. -- x-log: bool. Plot the x-axis values logarithmically; default: `False`. +- x-log: bool. Plot the x-axis values logarithmically. default: `False`. - y-min = float. Minimum y-axis value. default: empty - y-max: float. Maximum y-axis value. default: empty -- bs-normalized: bool. Normalize the measured values with the batch size before plotting; default: `True`. +- bs-normalized: bool. Normalize the measured values with the batch size before plotting. default: `True`. -- error-style: str. Style of errors / uncertainties due to averaging; choices: `bars`,`band`; default: `band`. +- error-style: str. Style of errors / uncertainties due to averaging. choices: `bars`,`band`. default: `band`. - top-right-label: str. When set, stick this string as label over the top right corner of the plot. default: empty. -- batch-sizes: int. The comma-separated list of batch sizes to be tested; default: `1,2,4`. +- default_colors: str. Default color cycle to use for plots. choices: `mpl`, `cms_6`, `atlas_10`. default: `cms_6`. + +- batch-sizes: int. The comma-separated list of batch sizes to be tested. default: `1,2,4`. - model-file: str. The absolute path of the yaml file containing the informations of the model to be tested. default: `$MLP_BASE/examples/dnn/model_tf_l10u128.yaml`. @@ -406,21 +410,23 @@ The number of inferences behind one plotted data point is given by `n-events * n - model-labels: str. The comma-separated list of model labels. When set, use these strings for the model labels in the plots from the plotting tasks. When empty, the `label` fields in the models yaml data are used when existing, else the `name` fields in the models yaml data are used when existing, and model-names otherwise. default: empty. -- y-log: bool. Plot the y-axis values logarithmically; default: `False`. +- y-log: bool. Plot the y-axis values logarithmically. default: `False`. -- x-log: bool. Plot the x-axis values logarithmically; default: `False`. +- x-log: bool. Plot the x-axis values logarithmically. default: `False`. - y-min = float. Minimum y-axis value. default: empty - y-max: float. Maximum y-axis value. default: empty -- bs-normalized: bool. Normalize the measured values with the batch size before plotting; default: `True`. +- bs-normalized: bool. Normalize the measured values with the batch size before plotting. default: `True`. -- error-style: str. Style of errors / uncertainties due to averaging; choices: `bars`,`band`; default: `band`. +- error-style: str. Style of errors / uncertainties due to averaging. choices: `bars`,`band`. default: `band`. - top-right-label: str. When set, stick this string as label over the top right corner of the plot. default: empty. -- batch-sizes: int. The comma-separated list of batch sizes to be tested; default: `1,2,4`. +- default_colors: str. Default color cycle to use for plots. choices: `mpl`, `cms_6`, `atlas_10`. default: `cms_6`. + +- batch-sizes: int. The comma-separated list of batch sizes to be tested. default: `1,2,4`. - n-events: int. The number of events to read from each input file for averaging measurements. default: `1` diff --git a/mlprof/plotting/plotter.py b/mlprof/plotting/plotter.py index 15fb04a..571a0f3 100644 --- a/mlprof/plotting/plotter.py +++ b/mlprof/plotting/plotter.py @@ -157,9 +157,9 @@ def plot_batch_size_several_measurements( # create plot with curves using a single color for each value-error pair legend_entries = [] - if plot_params.get("custom_colors"): + if plot_params.get("default_colors"): # set the color cycle to the custom color cycle - ax._get_lines.set_prop_cycle(cycler("color", colors[plot_params.get("custom_colors")])) + ax._get_lines.set_prop_cycle(cycler("color", colors[plot_params.get("default_colors")])) for i, data in enumerate(plot_data): color_used = color_list[i] if color_list[i] else ax._get_lines.get_next_color() diff --git a/mlprof/tasks/parameters.py b/mlprof/tasks/parameters.py index 358f6f8..7d41aee 100644 --- a/mlprof/tasks/parameters.py +++ b/mlprof/tasks/parameters.py @@ -313,7 +313,7 @@ class CustomPlotParameters(BaseTask): significant=False, description="stick a label over the top right corner of the plot", ) - custom_colors = luigi.ChoiceParameter( + default_colors = luigi.ChoiceParameter( choices=["mpl", "cms_6", "atlas_10"], default="cms_6", significant=False, @@ -331,5 +331,5 @@ def custom_plot_params(self): "bs_normalized": self.bs_normalized, "error_style": self.error_style, "top_right_label": None if self.top_right_label == law.NO_STR else self.top_right_label, - "custom_colors": self.custom_colors, + "default_colors": self.default_colors, }