Skip to content

Commit

Permalink
[Feature] Support SummarizeHook (#34)
Browse files Browse the repository at this point in the history
* update debug mode logic

* [Feature] Support SummarizeHook
  • Loading branch information
BIGWangYuDong committed Nov 21, 2024
1 parent d6f62ab commit 4582250
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 7 deletions.
3 changes: 2 additions & 1 deletion configs/detection/_base_/default_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=1),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='DetVisualizationHook'))
visualization=dict(type='DetVisualizationHook'),
summarizer=dict(type='lqit.SummarizeHook'))

randomness = dict(seed=None, deterministic=False)

Expand Down
3 changes: 2 additions & 1 deletion lqit/common/engine/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .lark_hook import LarkHook
from .summarize_hook import SummarizeHook

__all__ = ['LarkHook']
__all__ = ['LarkHook', 'SummarizeHook']
1 change: 0 additions & 1 deletion lqit/common/engine/hooks/lark_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def before_train(self, runner) -> None:
def before_test(self, runner) -> None:
if self.silent:
return
# TODO: Check
title = 'Task Initiation Report'
content = f"{self.user_name}'s task has started testing!\n" \
f'Config file: {self.cfg_file}\n' \
Expand Down
129 changes: 129 additions & 0 deletions lqit/common/engine/hooks/summarize_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import logging
import os.path as osp
from pathlib import Path
from typing import Dict, Optional, Union

import tabulate
from mmengine.dist import master_only
from mmengine.hooks import Hook
from mmengine.logging import print_log

from lqit.registry import HOOKS


@HOOKS.register_module()
class SummarizeHook(Hook):
    """Summarize Hook, saving the metrics into a csv file.

    Gathers the metrics produced after every validation/test epoch and, at
    the end of the run, writes them out as a table (screen + ``.txt``) and
    as a ``.csv`` file.

    Args:
        summary_file (str): The name of the summary file. A ``.csv`` suffix
            is appended if missing. Defaults to 'gather_results.csv'.
        out_dir (str): The output directory. If not specified, it will be set
            to the log directory of the runner. Defaults to None.
        by_epoch (bool): Whether to save the metrics by epoch or by iteration.
            Defaults to True.
    """
    # Run late so other hooks have finished touching the metrics.
    priority = 'VERY_LOW'

    def __init__(self,
                 summary_file: str = 'gather_results.csv',
                 out_dir: Optional[Union[str, Path]] = None,
                 by_epoch: bool = True):
        if not summary_file.endswith('.csv'):
            summary_file += '.csv'

        if out_dir is not None and not isinstance(out_dir, (str, Path)):
            raise TypeError('out_dir must be a str or Path object')
        self.out_dir = out_dir

        # Template used to name each gathered row during validation.
        if by_epoch:
            self.metric_tmpl = 'epoch_{}'
        else:
            self.metric_tmpl = 'iter_{}'

        self.summary_file = summary_file
        self.by_epoch = by_epoch
        # Column names; set lazily from the first non-empty metrics dict.
        self.header: Optional[list] = None
        # Maps row name (epoch/iter/checkpoint stem) -> list of metric
        # strings, or None when no metrics were produced for that row.
        self.gather_results: Dict[str, Optional[list]] = dict()

    def before_run(self, runner) -> None:
        """Set the output directory to the log directory of the runner if
        `out_dir` is not specified."""
        if self.out_dir is None:
            self.out_dir = runner.log_dir

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """Gather validation metrics, keyed by the current epoch/iter."""
        if self.by_epoch:
            name = self.metric_tmpl.format(runner.epoch)
        else:
            name = self.metric_tmpl.format(runner.iter)
        self.process_metrics(name, metrics)

    def after_test_epoch(self,
                         runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
        """Gather test metrics, keyed by the loaded checkpoint's name."""
        # NOTE(review): relies on the private ``runner._load_from`` — assumed
        # to hold the checkpoint path passed to the runner; confirm upstream.
        ckpt_path = runner._load_from
        if ckpt_path is None:
            # No checkpoint was loaded (e.g. testing a freshly built model);
            # fall back to a generic row name instead of crashing on
            # ``osp.basename(None)``.
            name = 'test'
        else:
            name = osp.basename(ckpt_path).split('.')[0]
        self.process_metrics(name, metrics)

    def process_metrics(self, name: str,
                        metrics: Optional[Dict[str, float]]) -> None:
        """Record one row of metrics under ``name``.

        The first non-empty ``metrics`` dict fixes the column header; later
        rows are aligned to it by key (missing keys become ``'-'``).  A
        ``None`` or empty ``metrics`` records a placeholder row.
        """
        if metrics is None:
            # Both callers declare metrics as Optional; treat None as empty
            # instead of raising TypeError on len(None).
            metrics = {}

        if self.header is None and len(metrics) > 0:
            self.header = list(metrics.keys())

        if len(metrics) > 0:
            # Align values with the established header so rows stay in a
            # consistent column order even if the dict order differs.
            row = [str(metrics.get(key, '-')) for key in self.header]
        else:
            row = None

        self.gather_results[name] = row

    @master_only
    def summarize(self) -> None:
        """Write the gathered metrics to screen, a txt file and a csv file."""
        csv_file = osp.join(self.out_dir, self.summary_file)
        txt_file = osp.join(self.out_dir,
                            self.summary_file.replace('.csv', '.txt'))
        table = []
        header = ['Architecture'] + self.header
        table.append(header)
        for key, row in self.gather_results.items():
            if row is None:
                # Pad only the metric columns: the name column is prepended
                # below, so padding to len(header) would misalign the table.
                row = ['-'] * len(self.header)
            table.append([key] + row)
        # output to screen
        print(tabulate.tabulate(table, headers='firstrow'))
        # output to txt file
        with open(txt_file, 'w', encoding='utf-8') as f:
            f.write(tabulate.tabulate(table, headers='firstrow'))

        # output to csv file
        with open(csv_file, 'w', encoding='utf-8') as f:
            f.write('\n'.join([','.join(row) for row in table]) + '\n')

        print_log(
            f'Summary results have been saved to {csv_file}.',
            logger='current',
            level=logging.INFO)

    def after_run(self, runner) -> None:
        """Save the gathered metrics, warning instead of failing when there
        is nothing to save or nowhere to save it."""
        if self.out_dir is None:
            print_log(
                '`SummarizeHook.out_dir` is not specified, cannot save '
                'the summary file.',
                logger='current',
                level=logging.WARNING)
        elif self.header is None:
            print_log(
                'No metrics have been gathered from the runner. '
                'Cannot save the summary file.',
                logger='current',
                level=logging.WARNING)
        else:
            self.summarize()
4 changes: 4 additions & 0 deletions lqit/common/utils/lark_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def __init__(self) -> None:
self.cfg_file = None
self.task_type = None
self.url = None
self.ckpt_path = None

def monitor_exception(self) -> None:
"""Catch and format exception information, send alert message to
Expand Down Expand Up @@ -267,6 +268,7 @@ def start_monitor(self,
content = f"{self.user_name}'s {self.task_type} task has started!\n" \
f'Config file: {self.cfg_file}\n'
if ckpt_path is not None:
self.ckpt_path = ckpt_path
content += f'Checkpoint file: {ckpt_path}'
rank = get_rank()
if rank == '0' or rank == 0 or rank is None:
Expand All @@ -283,6 +285,8 @@ def stop_monitor(self) -> None:

content = f"{self.user_name}'s {self.task_type} task completed!\n" \
f'Config file: {self.cfg_file}\n'
if self.ckpt_path is not None:
content += f'Checkpoint file: {self.ckpt_path}\n'
if os.getenv('LAST_METRIC_RESULTS') is not None:
metric_content = os.getenv('LAST_METRIC_RESULTS')
content += metric_content
Expand Down
6 changes: 4 additions & 2 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ def main(args):

# process debug mode if args.debug is True
if args.debug:
# force set args.lark = False
args.lark = False
# set necessary params for debug mode
cfg = process_debug_mode(cfg)

Expand Down Expand Up @@ -149,6 +147,10 @@ def main(args):

monitor_manager = None

if args.debug:
# force set args.lark = False
args.lark = False

if not args.lark:
main(args)
else:
Expand Down
6 changes: 4 additions & 2 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ def main(args):

# process debug mode if args.debug is True
if args.debug:
# force set args.lark = False
args.lark = False
# set necessary params for debug mode
cfg = process_debug_mode(cfg)

Expand Down Expand Up @@ -172,6 +170,10 @@ def main(args):

monitor_manager = None

if args.debug:
# force set args.lark = False
args.lark = False

if args.lark:
# report the running status to lark bot
lark_file = args.lark_file
Expand Down

0 comments on commit 4582250

Please sign in to comment.