From 59735d072b41d3b84512fbc7864e1838c8721bf0 Mon Sep 17 00:00:00 2001 From: pang Date: Mon, 20 Mar 2023 15:47:44 +0800 Subject: [PATCH] =?UTF-8?q?1.=20=E4=BF=AE=E5=A4=8D=E7=BB=98=E5=9B=BE?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E4=B8=AD=E7=9A=84=E4=B8=80=E4=BA=9B=E9=97=AE?= =?UTF-8?q?=E9=A2=98=E3=80=82=202.=20=E5=AE=8C=E5=96=84=E5=AF=B9=E4=BA=8E?= =?UTF-8?q?=20E-measure=20=E7=BB=98=E5=9B=BE=E7=9A=84=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E3=80=82=203.=20=E8=A1=A5=E5=85=85=E4=B8=80=E4=BA=9B=E7=BB=98?= =?UTF-8?q?=E5=9B=BE=E7=9A=84=E5=B1=95=E7=A4=BA=EF=BC=8C=E8=BF=99=E9=87=8C?= =?UTF-8?q?=E4=BB=A5=E6=88=91=E8=87=AA=E5=B7=B1=E7=9A=84=20RGB-D=20SOD=20?= =?UTF-8?q?=E8=AE=BA=E6=96=87=20CAVER=20(TIP=202023)=20=E7=9A=84=E8=AE=BA?= =?UTF-8?q?=E6=96=87=E7=BB=93=E6=9E=9C=E4=B8=BA=E4=BE=8B=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + metrics/draw_curves.py | 35 +++++++++++++++++++++++++---------- plot.py | 11 +++++++++-- readme.md | 20 +++++++++++++++++++- utils/print_formatter.py | 4 +++- 5 files changed, 57 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index 6e12f29..39937a9 100755 --- a/.gitignore +++ b/.gitignore @@ -282,3 +282,4 @@ gen /*.sh /results/rgb_sod.md /results/htmls/*.html +!/.github/assets/*.jpg diff --git a/metrics/draw_curves.py b/metrics/draw_curves.py index f6d3333..91a123b 100644 --- a/metrics/draw_curves.py +++ b/metrics/draw_curves.py @@ -9,7 +9,7 @@ def draw_curves( - for_pr: bool = True, + mode: str, axes_setting: dict = None, curves_npy_path: list = None, row_num: int = 1, @@ -20,14 +20,13 @@ def draw_curves( ncol_of_legend: int = 1, separated_legend: bool = False, sharey: bool = False, - line_styles=("-", "--"), line_width=3, save_name=None, ): """A better curve painter! Args: - for_pr (bool, optional): Plot for PR curves or FM curves. Defaults to True. + mode (str): `pr` for PR curves, `fm` for F-measure curves, and `em' for E-measure curves. axes_setting (dict, optional): Setting for axes. Defaults to None. curves_npy_path (list, optional): Paths of curve npy files. Defaults to None. row_num (int, optional): Number of rows. Defaults to 1. @@ -38,11 +37,10 @@ def draw_curves( ncol_of_legend (int, optional): Number of columns for the legend. Defaults to 1. separated_legend (bool, optional): Use the separated legend. Defaults to False. sharey (bool, optional): Use a shared y-axis. Defaults to False. - line_styles (tuple, optional): Styles of lines. Defaults to ("-", "--"). line_width (int, optional): Width of lines. Defaults to 3. save_name (str, optional): Name or path (without the extension format). Defaults to None. """ - mode = "pr" if for_pr else "fm" + assert mode in ["pr", "fm", "em"] save_name = save_name or mode mode_axes_setting = axes_setting[mode] @@ -97,8 +95,9 @@ def draw_curves( # assert len(our_methods) <= len(line_styles) else: our_methods = [] + num_our_methods = len(our_methods) - # Give each method a unique color. + # Give each method a unique color and style. color_table = sorted( [ color @@ -106,14 +105,26 @@ def draw_curves( if name not in ["red", "white"] or not name.startswith("light") or "gray" in name ] ) + style_table = ["-", "--", "-.", ":", "."] + unique_method_settings = OrderedDict() for i, method_name in enumerate(target_unique_method_names): + if i < num_our_methods: + line_color = "red" + line_style = style_table[i % len(style_table)] + else: + other_idx = i - num_our_methods + line_color = color_table[other_idx] + line_style = style_table[other_idx % 2] + unique_method_settings[method_name] = { - "line_color": "red" if i < len(our_methods) else color_table[i], + "line_color": line_color, "line_label": method_aliases.get(method_name, method_name), - "line_style": line_styles[i % len(line_styles)], + "line_style": line_style, "line_width": line_width, } + # ensure that our methods are drawn last to avoid being overwritten by other methods + target_unique_method_names.reverse() curve_drawer = CurveDrawer( row_num=row_num, @@ -135,9 +146,13 @@ def draw_curves( y_ticks=y_ticks, ) - for method_name, method_setting in unique_method_settings.items(): + for method_name in target_unique_method_names: + method_setting = unique_method_settings[method_name] + if method_name not in dataset_results: - raise KeyError(f"{method_name} not in {sorted(dataset_results.keys())}") + print(f"{method_name} will be skipped for {dataset_name}!") + continue + method_results = dataset_results[method_name] if mode == "pr": y_data = method_results.get("p") diff --git a/plot.py b/plot.py index 0f97a08..7bf2437 100644 --- a/plot.py +++ b/plot.py @@ -73,7 +73,7 @@ def get_args(): parser.add_argument( "--mode", type=str, - choices=["pr", "fm"], + choices=["pr", "fm", "em"], default="pr", help="Mode for plotting. Default: pr", ) @@ -96,7 +96,7 @@ def main(args): dataset_aliases = aliases.get("dataset") draw_curves.draw_curves( - for_pr=args.mode == "pr", + mode=args.mode, # 不同曲线的绘图配置 axes_setting={ # pr曲线的配置 @@ -113,6 +113,13 @@ def main(args): "x_ticks": np.linspace(0, 1, 6), "y_ticks": np.linspace(0.6, 1, 6), }, + # em曲线的配置 + "em": { + "x_label": "Threshold", + "y_label": r"E$_{m}$", + "x_ticks": np.linspace(0, 1, 6), + "y_ticks": np.linspace(0.7, 1, 6), + }, }, curves_npy_path=args.curves_npys, row_num=args.num_rows, diff --git a/readme.md b/readme.md index 44e1ff1..9f66336 100644 --- a/readme.md +++ b/readme.md @@ -176,7 +176,7 @@ A Python-based image binary segmentation evaluation toolbox. ### 为灰度图像的评估绘制曲线 -可以使用 `plot.py` 来读取 `.npy` 文件按需对指定方法和数据集的结果整理并绘制 `PR` 曲线和 `Fm` 曲线. 该脚本用法可见 `python plot.py --help` 的输出. 按照自己需求添加配置项并执行即可. +可以使用 `plot.py` 来读取 `.npy` 文件按需对指定方法和数据集的结果整理并绘制 `PR` , `F-measure` 和 `E-measure` 曲线. 该脚本用法可见 `python plot.py --help` 的输出. 按照自己需求添加配置项并执行即可. 最基本的一条是请按照子图数量, 合理地指定配置文件中的 `figure.figsize` 项的数值. @@ -223,6 +223,20 @@ python plot.py --style-cfg examples/single_row_style.yml --num-rows 1 --curves-n python plot.py --style-cfg examples/single_row_style.yml --num-rows 1 --curves-npys output/rgb_sod/curves.npy --our-methods MINet_R50_2020 --num-col-legend 1 --mode pr --separated-legend --sharey --save-name output/rgb_sod/complex_curve_pr ``` +## 绘图示例 + +**Precision-Recall Curve**: + +![PRCurves](https://user-images.githubusercontent.com/26847524/227249768-a41ef076-6355-4b96-a291-fc0e071d9d35.jpg) + +**F-measure Curve**: + +![fm-curves](https://user-images.githubusercontent.com/26847524/227249746-f61d7540-bb73-464d-bccf-9a36323dec47.jpg) + +**E-measure Curve**: + +![em-curves](https://user-images.githubusercontent.com/26847524/227249727-8323d5cf-ddd7-427b-8152-b8f47781c4e3.jpg) + ## 相关文献 ```text @@ -282,6 +296,10 @@ python plot.py --style-cfg examples/single_row_style.yml --num-rows 1 --curves-n ## 更新日志 +* 2023年3月23日 + 1. 修复绘图代码中的一些问题。 + 2. 完善对于 E-measure 绘图的支持。 + 3. 补充一些绘图的展示,这里以我自己的 RGB-D SOD 论文 CAVER (TIP 2023) 的论文结果为例。 * 2023年3月20日 1. 提供更丰富的指标的支持。 2. 更新`readme.md`和示例文件。 diff --git a/utils/print_formatter.py b/utils/print_formatter.py index c5196c0..e1405db 100644 --- a/utils/print_formatter.py +++ b/utils/print_formatter.py @@ -90,7 +90,9 @@ def formatter_for_tabulate( table = [] headers = ["methods"] for method_name in method_names: - metric_info = dataset_metrics[method_name] + metric_info = dataset_metrics.get(method_name) + if metric_info is None: + continue if method_name_length: method_name = clip_string(method_name, max_length=method_name_length, mode="left")