Skip to content

Commit

Permalink
1. 修复绘图代码中的一些问题。
Browse files Browse the repository at this point in the history
2. 完善对于 E-measure 绘图的支持。
3. 补充一些绘图的展示,这里以我自己的 RGB-D SOD 论文 CAVER (TIP 2023) 的论文结果为例。
  • Loading branch information
lartpang committed Mar 23, 2023
1 parent 6c97b0c commit 59735d0
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,4 @@ gen
/*.sh
/results/rgb_sod.md
/results/htmls/*.html
!/.github/assets/*.jpg
35 changes: 25 additions & 10 deletions metrics/draw_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def draw_curves(
for_pr: bool = True,
mode: str,
axes_setting: dict = None,
curves_npy_path: list = None,
row_num: int = 1,
Expand All @@ -20,14 +20,13 @@ def draw_curves(
ncol_of_legend: int = 1,
separated_legend: bool = False,
sharey: bool = False,
line_styles=("-", "--"),
line_width=3,
save_name=None,
):
"""A better curve painter!
Args:
for_pr (bool, optional): Plot for PR curves or FM curves. Defaults to True.
mode (str): `pr` for PR curves, `fm` for F-measure curves, and `em' for E-measure curves.
axes_setting (dict, optional): Setting for axes. Defaults to None.
curves_npy_path (list, optional): Paths of curve npy files. Defaults to None.
row_num (int, optional): Number of rows. Defaults to 1.
Expand All @@ -38,11 +37,10 @@ def draw_curves(
ncol_of_legend (int, optional): Number of columns for the legend. Defaults to 1.
separated_legend (bool, optional): Use the separated legend. Defaults to False.
sharey (bool, optional): Use a shared y-axis. Defaults to False.
line_styles (tuple, optional): Styles of lines. Defaults to ("-", "--").
line_width (int, optional): Width of lines. Defaults to 3.
save_name (str, optional): Name or path (without the extension format). Defaults to None.
"""
mode = "pr" if for_pr else "fm"
assert mode in ["pr", "fm", "em"]
save_name = save_name or mode
mode_axes_setting = axes_setting[mode]

Expand Down Expand Up @@ -97,23 +95,36 @@ def draw_curves(
# assert len(our_methods) <= len(line_styles)
else:
our_methods = []
num_our_methods = len(our_methods)

# Give each method a unique color.
# Give each method a unique color and style.
color_table = sorted(
[
color
for name, color in colors.cnames.items()
if name not in ["red", "white"] or not name.startswith("light") or "gray" in name
]
)
style_table = ["-", "--", "-.", ":", "."]

unique_method_settings = OrderedDict()
for i, method_name in enumerate(target_unique_method_names):
if i < num_our_methods:
line_color = "red"
line_style = style_table[i % len(style_table)]
else:
other_idx = i - num_our_methods
line_color = color_table[other_idx]
line_style = style_table[other_idx % 2]

unique_method_settings[method_name] = {
"line_color": "red" if i < len(our_methods) else color_table[i],
"line_color": line_color,
"line_label": method_aliases.get(method_name, method_name),
"line_style": line_styles[i % len(line_styles)],
"line_style": line_style,
"line_width": line_width,
}
# ensure that our methods are drawn last to avoid being overwritten by other methods
target_unique_method_names.reverse()

curve_drawer = CurveDrawer(
row_num=row_num,
Expand All @@ -135,9 +146,13 @@ def draw_curves(
y_ticks=y_ticks,
)

for method_name, method_setting in unique_method_settings.items():
for method_name in target_unique_method_names:
method_setting = unique_method_settings[method_name]

if method_name not in dataset_results:
raise KeyError(f"{method_name} not in {sorted(dataset_results.keys())}")
print(f"{method_name} will be skipped for {dataset_name}!")
continue

method_results = dataset_results[method_name]
if mode == "pr":
y_data = method_results.get("p")
Expand Down
11 changes: 9 additions & 2 deletions plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_args():
parser.add_argument(
"--mode",
type=str,
choices=["pr", "fm"],
choices=["pr", "fm", "em"],
default="pr",
help="Mode for plotting. Default: pr",
)
Expand All @@ -96,7 +96,7 @@ def main(args):
dataset_aliases = aliases.get("dataset")

draw_curves.draw_curves(
for_pr=args.mode == "pr",
mode=args.mode,
# 不同曲线的绘图配置
axes_setting={
# pr曲线的配置
Expand All @@ -113,6 +113,13 @@ def main(args):
"x_ticks": np.linspace(0, 1, 6),
"y_ticks": np.linspace(0.6, 1, 6),
},
# em曲线的配置
"em": {
"x_label": "Threshold",
"y_label": r"E$_{m}$",
"x_ticks": np.linspace(0, 1, 6),
"y_ticks": np.linspace(0.7, 1, 6),
},
},
curves_npy_path=args.curves_npys,
row_num=args.num_rows,
Expand Down
20 changes: 19 additions & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ A Python-based image binary segmentation evaluation toolbox.

### 为灰度图像的评估绘制曲线

可以使用 `plot.py` 来读取 `.npy` 文件按需对指定方法和数据集的结果整理并绘制 `PR` 曲线和 `Fm` 曲线. 该脚本用法可见 `python plot.py --help` 的输出. 按照自己需求添加配置项并执行即可.
可以使用 `plot.py` 来读取 `.npy` 文件按需对指定方法和数据集的结果整理并绘制 `PR` , `F-measure``E-measure` 曲线. 该脚本用法可见 `python plot.py --help` 的输出. 按照自己需求添加配置项并执行即可.

最基本的一条是请按照子图数量, 合理地指定配置文件中的 `figure.figsize` 项的数值.

Expand Down Expand Up @@ -223,6 +223,20 @@ python plot.py --style-cfg examples/single_row_style.yml --num-rows 1 --curves-n
python plot.py --style-cfg examples/single_row_style.yml --num-rows 1 --curves-npys output/rgb_sod/curves.npy --our-methods MINet_R50_2020 --num-col-legend 1 --mode pr --separated-legend --sharey --save-name output/rgb_sod/complex_curve_pr
```

## 绘图示例

**Precision-Recall Curve**:

![PRCurves](https://user-images.githubusercontent.com/26847524/227249768-a41ef076-6355-4b96-a291-fc0e071d9d35.jpg)

**F-measure Curve**:

![fm-curves](https://user-images.githubusercontent.com/26847524/227249746-f61d7540-bb73-464d-bccf-9a36323dec47.jpg)

**E-measure Curve**:

![em-curves](https://user-images.githubusercontent.com/26847524/227249727-8323d5cf-ddd7-427b-8152-b8f47781c4e3.jpg)

## 相关文献

```text
Expand Down Expand Up @@ -282,6 +296,10 @@ python plot.py --style-cfg examples/single_row_style.yml --num-rows 1 --curves-n

## 更新日志

* 2023年3月23日
1. 修复绘图代码中的一些问题。
2. 完善对于 E-measure 绘图的支持。
3. 补充一些绘图的展示,这里以我自己的 RGB-D SOD 论文 CAVER (TIP 2023) 的论文结果为例。
* 2023年3月20日
1. 提供更丰富的指标的支持。
2. 更新`readme.md`和示例文件。
Expand Down
4 changes: 3 additions & 1 deletion utils/print_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def formatter_for_tabulate(
table = []
headers = ["methods"]
for method_name in method_names:
metric_info = dataset_metrics[method_name]
metric_info = dataset_metrics.get(method_name)
if metric_info is None:
continue

if method_name_length:
method_name = clip_string(method_name, max_length=method_name_length, mode="left")
Expand Down

0 comments on commit 59735d0

Please sign in to comment.