From 4d8ee157822ba8ef9706a77bcdd248628a674d21 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 15 Jul 2023 21:47:20 +0800 Subject: [PATCH] rename fast_conv_bn_eval to efficient_conv_bn_eval --- mmcv/cnn/bricks/conv_module.py | 44 ++++++++++++++++-------------- tests/test_cnn/test_conv_module.py | 37 +++++++++++++------------ 2 files changed, 44 insertions(+), 37 deletions(-) diff --git a/mmcv/cnn/bricks/conv_module.py b/mmcv/cnn/bricks/conv_module.py index e9fcfd67bd..6e05550798 100644 --- a/mmcv/cnn/bricks/conv_module.py +++ b/mmcv/cnn/bricks/conv_module.py @@ -15,8 +15,9 @@ from .padding import build_padding_layer -def fast_conv_bn_eval_forward(bn: _BatchNorm, conv: nn.modules.conv._ConvNd, - x: torch.Tensor): +def efficient_conv_bn_eval_forward(bn: _BatchNorm, + conv: nn.modules.conv._ConvNd, + x: torch.Tensor): """ Implementation based on https://arxiv.org/abs/2305.11624 "Tune-Mode ConvBN Blocks For Efficient Transfer Learning" @@ -115,9 +116,9 @@ class ConvModule(nn.Module): sequence of "conv", "norm" and "act". Common examples are ("conv", "norm", "act") and ("act", "conv", "norm"). Default: ('conv', 'norm', 'act'). - fast_conv_bn_eval (bool): Whether use fast conv when the consecutive - bn is in eval mode (either training or testing), as proposed in - https://arxiv.org/abs/2305.11624 . Default: False. + efficient_conv_bn_eval (bool): Whether use efficient conv when the + consecutive bn is in eval mode (either training or testing), as + proposed in https://arxiv.org/abs/2305.11624 . Default: `False`. """ _abbr_ = 'conv_block' @@ -138,7 +139,7 @@ def __init__(self, with_spectral_norm: bool = False, padding_mode: str = 'zeros', order: tuple = ('conv', 'norm', 'act'), - fast_conv_bn_eval: bool = False): + efficient_conv_bn_eval: bool = False): super().__init__() assert conv_cfg is None or isinstance(conv_cfg, dict) assert norm_cfg is None or isinstance(norm_cfg, dict) @@ -209,7 +210,7 @@ def __init__(self, else: self.norm_name = None # type: ignore - self.turn_on_fast_conv_bn_eval(fast_conv_bn_eval) + self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval) # build activation layer if self.with_activation: @@ -263,15 +264,16 @@ def forward(self, if self.with_explicit_padding: x = self.padding_layer(x) # if the next operation is norm and we have a norm layer in - # eval mode and we have enabled fast_conv_bn_eval for the conv - # operator, then activate the optimized forward and skip the - # next norm operator since it has been fused + # eval mode and we have enabled `efficient_conv_bn_eval` for + # the conv operator, then activate the optimized forward and + # skip the next norm operator since it has been fused if layer_index + 1 < len(self.order) and \ self.order[layer_index + 1] == 'norm' and norm and \ self.with_norm and not self.norm.training and \ - self.fast_conv_bn_eval_forward is not None: - self.conv.forward = partial(self.fast_conv_bn_eval_forward, - self.norm, self.conv) + self.efficient_conv_bn_eval_forward is not None: + self.conv.forward = partial( + self.efficient_conv_bn_eval_forward, self.norm, + self.conv) layer_index += 1 x = self.conv(x) del self.conv.forward @@ -284,20 +286,22 @@ def forward(self, layer_index += 1 return x - def turn_on_fast_conv_bn_eval(self, fast_conv_bn_eval=True): - # fast_conv_bn_eval works for conv + bn + def turn_on_efficient_conv_bn_eval(self, efficient_conv_bn_eval=True): + # efficient_conv_bn_eval works for conv + bn # with `track_running_stats` option - if fast_conv_bn_eval and self.norm \ + if efficient_conv_bn_eval and self.norm \ and isinstance(self.norm, _BatchNorm) \ and self.norm.track_running_stats: - self.fast_conv_bn_eval_forward = fast_conv_bn_eval_forward + # this is to bypass the flake8 check for 79 chars in one line + enabled = efficient_conv_bn_eval_forward + self.efficient_conv_bn_eval_forward = enabled else: - self.fast_conv_bn_eval_forward = None # type: ignore + self.efficient_conv_bn_eval_forward = None # type: ignore @staticmethod def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd, bn: torch.nn.modules.batchnorm._BatchNorm, - fast_conv_bn_eval=True) -> 'ConvModule': + efficient_conv_bn_eval=True) -> 'ConvModule': """Create a ConvModule from a conv and a bn module.""" self = ConvModule.__new__(ConvModule) super(ConvModule, self).__init__() @@ -331,6 +335,6 @@ def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd, self.norm_name, norm = 'bn', bn self.add_module(self.norm_name, norm) - self.turn_on_fast_conv_bn_eval(fast_conv_bn_eval) + self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval) return self diff --git a/tests/test_cnn/test_conv_module.py b/tests/test_cnn/test_conv_module.py index 55a17607d1..09cb5f4ace 100644 --- a/tests/test_cnn/test_conv_module.py +++ b/tests/test_cnn/test_conv_module.py @@ -75,27 +75,30 @@ def test_conv_module(): output = conv(x) assert output.shape == (1, 8, 255, 255) - # conv + norm with fast mode - fast_conv = ConvModule( - 3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True).eval() + # conv + norm with efficient mode + efficient_conv = ConvModule( + 3, 8, 2, norm_cfg=dict(type='BN'), efficient_conv_bn_eval=True).eval() plain_conv = ConvModule( - 3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=False).eval() - for fast_param, plain_param in zip(fast_conv.state_dict().values(), - plain_conv.state_dict().values()): - plain_param.copy_(fast_param) - - fast_mode_output = fast_conv(x) + 3, 8, 2, norm_cfg=dict(type='BN'), + efficient_conv_bn_eval=False).eval() + for efficient_param, plain_param in zip( + efficient_conv.state_dict().values(), + plain_conv.state_dict().values()): + plain_param.copy_(efficient_param) + + efficient_mode_output = efficient_conv(x) plain_mode_output = plain_conv(x) - assert torch.allclose(fast_mode_output, plain_mode_output, atol=1e-5) + assert torch.allclose(efficient_mode_output, plain_mode_output, atol=1e-5) - # `conv` attribute can be dynamically modified in fast mode - fast_conv = ConvModule( - 3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True).eval() + # `conv` attribute can be dynamically modified in efficient mode + efficient_conv = ConvModule( + 3, 8, 2, norm_cfg=dict(type='BN'), efficient_conv_bn_eval=True).eval() new_conv = nn.Conv2d(3, 8, 2).eval() - fast_conv.conv = new_conv - fast_mode_output = fast_conv(x) - plain_mode_output = fast_conv.activate(fast_conv.norm(new_conv(x))) - assert torch.allclose(fast_mode_output, plain_mode_output, atol=1e-5) + efficient_conv.conv = new_conv + efficient_mode_output = efficient_conv(x) + plain_mode_output = efficient_conv.activate( + efficient_conv.norm(new_conv(x))) + assert torch.allclose(efficient_mode_output, plain_mode_output, atol=1e-5) # conv + act conv = ConvModule(3, 8, 2)