diff --git a/configs/detection/_base_/default_runtime.py b/configs/detection/_base_/default_runtime.py
index e58270f..9ee6765 100644
--- a/configs/detection/_base_/default_runtime.py
+++ b/configs/detection/_base_/default_runtime.py
@@ -6,7 +6,8 @@
     param_scheduler=dict(type='ParamSchedulerHook'),
     checkpoint=dict(type='CheckpointHook', interval=1),
     sampler_seed=dict(type='DistSamplerSeedHook'),
-    visualization=dict(type='DetVisualizationHook'))
+    visualization=dict(type='DetVisualizationHook'),
+    summarizer=dict(type='lqit.SummarizeHook'))
 
 randomness = dict(seed=None, deterministic=False)
 
diff --git a/lqit/common/engine/hooks/__init__.py b/lqit/common/engine/hooks/__init__.py
index 3e74c27..8287c43 100644
--- a/lqit/common/engine/hooks/__init__.py
+++ b/lqit/common/engine/hooks/__init__.py
@@ -1,3 +1,4 @@
 from .lark_hook import LarkHook
+from .summarize_hook import SummarizeHook
 
-__all__ = ['LarkHook']
+__all__ = ['LarkHook', 'SummarizeHook']
diff --git a/lqit/common/engine/hooks/lark_hook.py b/lqit/common/engine/hooks/lark_hook.py
index f9a19a0..8ef7ba6 100644
--- a/lqit/common/engine/hooks/lark_hook.py
+++ b/lqit/common/engine/hooks/lark_hook.py
@@ -139,7 +139,6 @@ def before_train(self, runner) -> None:
     def before_test(self, runner) -> None:
         if self.silent:
             return
-        # TODO: Check
         title = 'Task Initiation Report'
         content = f"{self.user_name}'s task has started testing!\n" \
                   f'Config file: {self.cfg_file}\n' \
diff --git a/lqit/common/engine/hooks/summarize_hook.py b/lqit/common/engine/hooks/summarize_hook.py
new file mode 100644
index 0000000..1a33bcb
--- /dev/null
+++ b/lqit/common/engine/hooks/summarize_hook.py
@@ -0,0 +1,129 @@
+import logging
+import os.path as osp
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+import tabulate
+from mmengine.dist import master_only
+from mmengine.hooks import Hook
+from mmengine.logging import print_log
+
+from lqit.registry import HOOKS
+
+
+@HOOKS.register_module()
+class SummarizeHook(Hook):
+    """Summarize hook, saving the gathered metrics into a csv file.
+
+    Args:
+        summary_file (str): The name of the summary file.
+            Defaults to 'gather_results.csv'.
+        out_dir (str or Path, optional): The output directory. If not
+            specified, it will be set to the log directory of the runner.
+        by_epoch (bool): Whether to save the metrics by epoch or by
+            iteration. Defaults to True.
+ """ + priority = 'VERY_LOW' + + def __init__(self, + summary_file: str = 'gather_results.csv', + out_dir: Optional[Union[str, Path]] = None, + by_epoch: bool = True): + if not summary_file.endswith('.csv'): + summary_file += '.csv' + + if out_dir is not None and not isinstance(out_dir, (str, Path)): + raise TypeError('out_dir must be a str or Path object') + self.out_dir = out_dir + + if by_epoch: + self.metric_tmpl = 'epoch_{}' + else: + self.metric_tmpl = 'iter_{}' + + self.summary_file = summary_file + self.by_epoch = by_epoch + self.header = None + self.gather_results = dict() + + def before_run(self, runner) -> None: + """Set the output directory to the log directory of the runner if + `out_dir` is not specified.""" + if self.out_dir is None: + self.out_dir = runner.log_dir + + def after_val_epoch(self, + runner, + metrics: Optional[Dict[str, float]] = None) -> None: + if self.by_epoch: + name = self.metric_tmpl.format(runner.epoch) + else: + name = self.metric_tmpl.format(runner.iter) + self.process_metrics(name, metrics) + + def after_test_epoch(self, + runner, + metrics: Optional[Dict[str, float]] = None) -> None: + # name set as the checkpoint name + ckpt_path = runner._load_from + name = osp.basename(ckpt_path).split('.')[0] + self.process_metrics(name, metrics) + + def process_metrics(self, name, metrics: Dict[str, float]): + if self.header is None: + if len(metrics) > 0: + self.header = [key for key in metrics.keys()] + + if len(metrics) > 0: + row = [str(item) for item in metrics.values()] + else: + row = None + + if self.header is not None and row is not None: + assert len(self.header) == len(row) + + self.gather_results[name] = row + + @master_only + def summarize(self): + csv_file = osp.join(self.out_dir, self.summary_file) + txt_file = osp.join(self.out_dir, + self.summary_file.replace('.csv', '.txt')) + table = [] + header = ['Architecture'] + self.header + table.append(header) + for key, row in self.gather_results.items(): + if row is None: + row = ['-'] * len(header) + table.append([key] + row) + # output to screean + print(tabulate.tabulate(table, headers='firstrow')) + # output to txt file + with open(txt_file, 'w', encoding='utf-8') as f: + f.write(tabulate.tabulate(table, headers='firstrow')) + + # output to csv file + with open(csv_file, 'w', encoding='utf-8') as f: + f.write('\n'.join([','.join(row) for row in table]) + '\n') + + print_log( + f'Summary results have been saved to {csv_file}.', + logger='current', + level=logging.INFO) + + def after_run(self, runner) -> None: + # save into a csv file + if self.out_dir is None: + print_log( + '`SummarizeHook.out_dir` is not specified, cannot save ' + 'the summary file.', + logger='current', + level=logging.WARNING) + elif self.header is None: + print_log( + 'No metrics have been gathered from the runner. 
+                'Cannot save the summary file.',
+                logger='current',
+                level=logging.WARNING)
+        else:
+            self.summarize()
diff --git a/lqit/common/utils/lark_manager.py b/lqit/common/utils/lark_manager.py
index d9b2241..19d0368 100644
--- a/lqit/common/utils/lark_manager.py
+++ b/lqit/common/utils/lark_manager.py
@@ -195,6 +195,7 @@ def __init__(self) -> None:
         self.cfg_file = None
         self.task_type = None
         self.url = None
+        self.ckpt_path = None
 
     def monitor_exception(self) -> None:
         """Catch and format exception information, send alert message to
@@ -267,6 +268,7 @@ def start_monitor(self,
         content = f"{self.user_name}'s {self.task_type} task has started!\n" \
                   f'Config file: {self.cfg_file}\n'
         if ckpt_path is not None:
+            self.ckpt_path = ckpt_path
             content += f'Checkpoint file: {ckpt_path}'
         rank = get_rank()
         if rank == '0' or rank == 0 or rank is None:
@@ -283,6 +285,8 @@
     def stop_monitor(self) -> None:
         content = f"{self.user_name}'s {self.task_type} task completed!\n" \
                   f'Config file: {self.cfg_file}\n'
+        if self.ckpt_path is not None:
+            content += f'Checkpoint file: {self.ckpt_path}\n'
         if os.getenv('LAST_METRIC_RESULTS') is not None:
             metric_content = os.getenv('LAST_METRIC_RESULTS')
             content += metric_content
diff --git a/tools/test.py b/tools/test.py
index 3d45b60..071eaf5 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -108,8 +108,6 @@ def main(args):
 
     # process debug mode if args.debug is True
     if args.debug:
-        # force set args.lark = False
-        args.lark = False
         # set necessary params for debug mode
         cfg = process_debug_mode(cfg)
 
@@ -149,6 +147,10 @@
 
     monitor_manager = None
 
+    if args.debug:
+        # force set args.lark = False
+        args.lark = False
+
     if not args.lark:
         main(args)
     else:
diff --git a/tools/train.py b/tools/train.py
index 6609a90..3742d5d 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -95,8 +95,6 @@ def main(args):
 
     # process debug mode if args.debug is True
     if args.debug:
-        # force set args.lark = False
-        args.lark = False
         # set necessary params for debug mode
         cfg = process_debug_mode(cfg)
 
@@ -172,6 +170,10 @@
 
     monitor_manager = None
 
+    if args.debug:
+        # force set args.lark = False
+        args.lark = False
+
     if args.lark:
         # report the running status to lark bot
         lark_file = args.lark_file
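
Usage sketch: a minimal downstream config that enables and tunes the new hook, assuming the base runtime above is inherited via `_base_`. The relative path and file name are hypothetical; only `lqit.SummarizeHook` and its argument names come from the code in this diff.

# my_config.py (hypothetical) -- overrides the summarizer registered in
# configs/detection/_base_/default_runtime.py.
_base_ = ['../_base_/default_runtime.py']

default_hooks = dict(
    summarizer=dict(
        type='lqit.SummarizeHook',
        # '.csv' is appended automatically when the suffix is missing.
        summary_file='gather_results.csv',
        # None is resolved to runner.log_dir in before_run().
        out_dir=None,
        # True names rows 'epoch_{n}'; False names them 'iter_{n}'.
        by_epoch=True))

After a run, `SummarizeHook.after_run` writes `gather_results.csv` plus a tabulate-rendered `gather_results.txt` into the resolved `out_dir`, with one row per validation epoch (or per tested checkpoint).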