Skip to content

Commit

Permalink
* 2022年4月23日
Browse files Browse the repository at this point in the history
    - 为了便于使用和配置,对大量代码进行了调整和修改,与之前版本相比,使用上也存在部分差异。
    - 评估部分:
      - 支持多个方法的json文件同时使用评估。
      - 更新了指标统计类,便于更灵活的指定不同的指标。
      - 对一些医学二值分割的指标提供了支持。
    - 绘图部分:
      - 支持多个曲线npy文件同时用于绘图。
      - 将个性化配置尽可能独立出来,提供了独立的绘图配置文件。
      - 重构了绘图类,便于使用yaml文件对matplotlib的默认设定进行覆盖。
  • Loading branch information
lartpang committed Apr 24, 2022
1 parent 897612b commit 2a9cd17
Show file tree
Hide file tree
Showing 14 changed files with 756 additions and 154 deletions.
8 changes: 6 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# Big files
**/*.png
**/*.pdf
**/*.jpg
**/*.bmp
**/*.zip
**/*.7z
**/*.rar
**/*.tar*

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down Expand Up @@ -274,7 +278,7 @@ gen
/output/
/untracked/
/configs/
/*.py
/*.sh
# /*.py
# /*.sh
/results/rgb_sod.md
/results/htmls/*.html
153 changes: 153 additions & 0 deletions eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
import argparse
import os
import textwrap
import warnings

from metrics import cal_sod_matrics
from utils.generate_info import get_datasets_info, get_methods_info
from utils.misc import make_dir
from utils.recorders import METRIC_MAPPING


def get_args():
parser = argparse.ArgumentParser(
description=textwrap.dedent(
r"""
INCLUDE:
- F-measure-Threshold Curve
- Precision-Recall Curve
- MAE
- weighted F-measure
- S-measure
- max/average/adaptive F-measure
- max/average/adaptive E-measure
- max/average Precision
- max/average Sensitivity
- max/average Specificity
- max/average F-measure
- max/average Dice
- max/average IoU
NOTE:
- Our method automatically calculates the intersection of `pre` and `gt`.
- Currently supported pre naming rules: `prefix + gt_name_wo_ext + suffix_w_ext`
EXAMPLES:
python eval_all.py \
--dataset-json configs/datasets/json/rgbd_sod.json \
--method-json configs/methods/json/rgbd_other_methods.json configs/methods/json/rgbd_our_method.json --metric-npy output/rgbd_metrics.npy \
--curves-npy output/rgbd_curves.npy \
--record-tex output/rgbd_results.txt
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--dataset-json", required=True, type=str, help="Json file for datasets.")
parser.add_argument(
"--method-json", required=True, nargs="+", type=str, help="Json file for methods."
)
parser.add_argument("--metric-npy", type=str, help="Npy file for saving metric results.")
parser.add_argument("--curves-npy", type=str, help="Npy file for saving curve results.")
parser.add_argument("--record-txt", type=str, help="Txt file for saving metric results.")
parser.add_argument("--to-overwrite", action="store_true", help="To overwrite the txt file.")
parser.add_argument("--record-xlsx", type=str, help="Xlsx file for saving metric results.")
parser.add_argument(
"--include-methods",
type=str,
nargs="+",
help="Names of only specific methods you want to evaluate.",
)
parser.add_argument(
"--exclude-methods",
type=str,
nargs="+",
help="Names of some specific methods you do not want to evaluate.",
)
parser.add_argument(
"--include-datasets",
type=str,
nargs="+",
help="Names of only specific datasets you want to evaluate.",
)
parser.add_argument(
"--exclude-datasets",
type=str,
nargs="+",
help="Names of some specific datasets you do not want to evaluate.",
)
parser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of workers for multi-threading or multi-processing. Default: 4",
)
parser.add_argument(
"--num-bits",
type=int,
default=3,
help="Number of decimal places for showing results. Default: 3",
)
parser.add_argument(
"--metric-names",
type=str,
nargs="+",
default=["mae", "fm", "em", "sm", "wfm"],
choices=METRIC_MAPPING.keys(),
help="Names of metrics",
)
args = parser.parse_args()

if args.metric_npy is not None:
make_dir(os.path.dirname(args.metric_npy))
if args.curves_npy is not None:
make_dir(os.path.dirname(args.curves_npy))
if args.record_txt is not None:
make_dir(os.path.dirname(args.record_txt))
if args.record_xlsx is not None:
make_dir(os.path.dirname(args.record_xlsx))
if args.to_overwrite and not args.record_txt:
warnings.warn("--to-overwrite only works with a valid --record-txt")
return args


def main():
args = get_args()

# 包含所有数据集信息的字典
datasets_info = get_datasets_info(
datastes_info_json=args.dataset_json,
include_datasets=args.include_datasets,
exclude_datasets=args.exclude_datasets,
)
# 包含所有待比较模型结果的信息的字典
methods_info = get_methods_info(
methods_info_jsons=args.method_json,
for_drawing=True,
include_methods=args.include_methods,
exclude_methods=args.exclude_methods,
)

# 确保多进程在windows上也可以正常使用
cal_sod_matrics.cal_sod_matrics(
sheet_name="Results",
to_append=not args.to_overwrite,
txt_path=args.record_txt,
xlsx_path=args.record_xlsx,
methods_info=methods_info,
datasets_info=datasets_info,
curves_npy_path=args.curves_npy,
metrics_npy_path=args.metric_npy,
num_bits=args.num_bits,
num_workers=args.num_workers,
use_mp=False,
metric_names=args.metric_names,
ncols_tqdm=119,
)


if __name__ == "__main__":
main()
14 changes: 11 additions & 3 deletions metrics/cal_sod_matrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from utils.misc import get_gt_pre_with_name, get_name_list, make_dir
from utils.print_formatter import formatter_for_tabulate
from utils.recorders import MetricExcelRecorder, MetricRecorder, TxtRecorder
from utils.recorders import MetricExcelRecorder, MetricRecorder_V2, TxtRecorder


class Recorder:
Expand Down Expand Up @@ -80,6 +80,8 @@ def cal_sod_matrics(
num_bits: int = 3,
num_workers: int = 2,
use_mp: bool = False,
metric_names: tuple = ("mae", "fm", "em", "sm", "wfm"),
ncols_tqdm: int = 79,
):
"""
Save the results of all models on different datasets in a `npy` file in the form of a
Expand Down Expand Up @@ -112,6 +114,8 @@ def cal_sod_matrics(
:param num_bits: the number of bits used to format results
:param num_workers: the number of workers of multiprocessing or multithreading
:param use_mp: using multiprocessing or multithreading
:param metric_names: names of metrics
:param ncols_tqdm: number of columns for tqdm
"""
recorder = Recorder(
txt_path=txt_path,
Expand Down Expand Up @@ -181,6 +185,8 @@ def cal_sod_matrics(
desc=f"[{dataset_name}({len(gt_name_list)}):{method_name}({len(pre_name_list)})]",
proc_idx=procs_idx,
blocking=use_mp,
metric_names=metric_names,
ncols_tqdm=ncols_tqdm,
),
callback=partial(recorder.record, method_name=method_name),
)
Expand Down Expand Up @@ -211,16 +217,18 @@ def evaluate_data(
desc="",
proc_idx=None,
blocking=True,
metric_names=None,
ncols_tqdm=79,
):
metric_recoder = MetricRecorder()
metric_recoder = MetricRecorder_V2(metric_names=metric_names)
# https://github.com/tqdm/tqdm#parameters
# https://github.com/tqdm/tqdm/blob/master/examples/parallel_bars.py
tqdm_bar = tqdm(
names,
total=len(names),
desc=desc,
position=proc_idx,
ncols=79,
ncols=ncols_tqdm,
lock_args=None if blocking else (False,),
)
for name in tqdm_bar:
Expand Down
Loading

0 comments on commit 2a9cd17

Please sign in to comment.