Commit

rename fast_conv_bn_eval to efficient_conv_bn_eval
youkaichao committed Aug 3, 2023
1 parent 86a38aa commit 4d8ee15
Showing 2 changed files with 44 additions and 37 deletions.
44 changes: 24 additions & 20 deletions mmcv/cnn/bricks/conv_module.py
@@ -15,8 +15,9 @@
 from .padding import build_padding_layer
 
 
-def fast_conv_bn_eval_forward(bn: _BatchNorm, conv: nn.modules.conv._ConvNd,
-                              x: torch.Tensor):
+def efficient_conv_bn_eval_forward(bn: _BatchNorm,
+                                   conv: nn.modules.conv._ConvNd,
+                                   x: torch.Tensor):
     """
     Implementation based on https://arxiv.org/abs/2305.11624
     "Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
@@ -115,9 +116,9 @@ class ConvModule(nn.Module):
             sequence of "conv", "norm" and "act". Common examples are
             ("conv", "norm", "act") and ("act", "conv", "norm").
             Default: ('conv', 'norm', 'act').
-        fast_conv_bn_eval (bool): Whether use fast conv when the consecutive
-            bn is in eval mode (either training or testing), as proposed in
-            https://arxiv.org/abs/2305.11624 . Default: False.
+        efficient_conv_bn_eval (bool): Whether use efficient conv when the
+            consecutive bn is in eval mode (either training or testing), as
+            proposed in https://arxiv.org/abs/2305.11624 . Default: `False`.
     """
 
     _abbr_ = 'conv_block'
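
The body of efficient_conv_bn_eval_forward is truncated in the first hunk above. For context, the optimization rests on the standard eval-mode identity: a BatchNorm with frozen statistics is a per-channel affine transform, so bn(conv(x)) can be evaluated as a single convolution with rescaled weights. Below is a minimal sketch of that identity, assuming an affine BatchNorm2d with tracked running statistics; it is not the exact mmcv implementation:

    import torch
    import torch.nn as nn

    def fold_bn_into_conv_params(conv: nn.Conv2d, bn: nn.BatchNorm2d):
        # assumes bn has affine=True and track_running_stats=True
        # per-channel scale: gamma / sqrt(running_var + eps)
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        # rescale conv weights along the output-channel axis
        weight = conv.weight * scale.reshape(-1, 1, 1, 1)
        bias = conv.bias if conv.bias is not None \
            else torch.zeros_like(bn.running_mean)
        # shift and rescale the bias so conv(x; weight, bias) == bn(conv(x))
        bias = (bias - bn.running_mean) * scale + bn.bias
        return weight, bias

Because this folding is differentiable, computing the folded parameters on the fly in each forward (the paper's "tune mode") still lets gradients reach the original conv and bn parameters during transfer learning.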
@@ -138,7 +139,7 @@ def __init__(self,
                  with_spectral_norm: bool = False,
                  padding_mode: str = 'zeros',
                  order: tuple = ('conv', 'norm', 'act'),
-                 fast_conv_bn_eval: bool = False):
+                 efficient_conv_bn_eval: bool = False):
         super().__init__()
         assert conv_cfg is None or isinstance(conv_cfg, dict)
         assert norm_cfg is None or isinstance(norm_cfg, dict)
@@ -209,7 +210,7 @@ def __init__(self,
         else:
             self.norm_name = None  # type: ignore
 
-        self.turn_on_fast_conv_bn_eval(fast_conv_bn_eval)
+        self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)
 
         # build activation layer
         if self.with_activation:
@@ -263,15 +264,16 @@ def forward(self,
                 if self.with_explicit_padding:
                     x = self.padding_layer(x)
                 # if the next operation is norm and we have a norm layer in
-                # eval mode and we have enabled fast_conv_bn_eval for the conv
-                # operator, then activate the optimized forward and skip the
-                # next norm operator since it has been fused
+                # eval mode and we have enabled `efficient_conv_bn_eval` for
+                # the conv operator, then activate the optimized forward and
+                # skip the next norm operator since it has been fused
                 if layer_index + 1 < len(self.order) and \
                         self.order[layer_index + 1] == 'norm' and norm and \
                         self.with_norm and not self.norm.training and \
-                        self.fast_conv_bn_eval_forward is not None:
-                    self.conv.forward = partial(self.fast_conv_bn_eval_forward,
-                                                self.norm, self.conv)
+                        self.efficient_conv_bn_eval_forward is not None:
+                    self.conv.forward = partial(
+                        self.efficient_conv_bn_eval_forward, self.norm,
+                        self.conv)
                     layer_index += 1
                     x = self.conv(x)
                     del self.conv.forward
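
The temporary `self.conv.forward = partial(...)` assignment works because setting an instance attribute shadows the method defined on the class, and `del self.conv.forward` removes that shadow so attribute lookup falls back to the class method, restoring the plain convolution. A self-contained illustration of the mechanism (hypothetical `Conv` class, not mmcv code):

    from functools import partial

    class Conv:
        def forward(self, x):
            return f'plain conv({x})'

    def efficient_forward(norm, conv, x):
        # stands in for efficient_conv_bn_eval_forward(bn, conv, x)
        return f'fused {norm}+conv({x})'

    c = Conv()
    # instance attribute shadows the class method; partial pre-binds the
    # norm and conv arguments, just as in ConvModule.forward above
    c.forward = partial(efficient_forward, 'bn', c)
    assert c.forward(1) == 'fused bn+conv(1)'
    del c.forward  # shadow removed; lookup falls back to Conv.forward
    assert c.forward(1) == 'plain conv(1)'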
@@ -284,20 +286,22 @@
             layer_index += 1
         return x
 
-    def turn_on_fast_conv_bn_eval(self, fast_conv_bn_eval=True):
-        # fast_conv_bn_eval works for conv + bn
+    def turn_on_efficient_conv_bn_eval(self, efficient_conv_bn_eval=True):
+        # efficient_conv_bn_eval works for conv + bn
         # with `track_running_stats` option
-        if fast_conv_bn_eval and self.norm \
+        if efficient_conv_bn_eval and self.norm \
                 and isinstance(self.norm, _BatchNorm) \
                 and self.norm.track_running_stats:
-            self.fast_conv_bn_eval_forward = fast_conv_bn_eval_forward
+            # this is to bypass the flake8 check for 79 chars in one line
+            enabled = efficient_conv_bn_eval_forward
+            self.efficient_conv_bn_eval_forward = enabled
         else:
-            self.fast_conv_bn_eval_forward = None  # type: ignore
+            self.efficient_conv_bn_eval_forward = None  # type: ignore
 
     @staticmethod
     def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd,
                             bn: torch.nn.modules.batchnorm._BatchNorm,
-                            fast_conv_bn_eval=True) -> 'ConvModule':
+                            efficient_conv_bn_eval=True) -> 'ConvModule':
         """Create a ConvModule from a conv and a bn module."""
         self = ConvModule.__new__(ConvModule)
         super(ConvModule, self).__init__()
@@ -331,6 +335,6 @@ def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd,
         self.norm_name, norm = 'bn', bn
         self.add_module(self.norm_name, norm)
 
-        self.turn_on_fast_conv_bn_eval(fast_conv_bn_eval)
+        self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)
 
         return self
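
`create_from_conv_bn` wraps an already-built conv/bn pair, e.g. one taken from a pretrained network, without re-initializing their weights. A minimal usage sketch, assuming the parts of the method elided from this diff set up the remaining ConvModule attributes as usual:

    import torch
    import torch.nn as nn
    from mmcv.cnn import ConvModule

    conv = nn.Conv2d(3, 8, 2)
    bn = nn.BatchNorm2d(8)
    module = ConvModule.create_from_conv_bn(
        conv, bn, efficient_conv_bn_eval=True).eval()
    out = module(torch.rand(1, 3, 32, 32))  # runs the fused eval-mode path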
37 changes: 20 additions & 17 deletions tests/test_cnn/test_conv_module.py
@@ -75,27 +75,30 @@ def test_conv_module():
     output = conv(x)
     assert output.shape == (1, 8, 255, 255)
 
-    # conv + norm with fast mode
-    fast_conv = ConvModule(
-        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True).eval()
+    # conv + norm with efficient mode
+    efficient_conv = ConvModule(
+        3, 8, 2, norm_cfg=dict(type='BN'), efficient_conv_bn_eval=True).eval()
     plain_conv = ConvModule(
-        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=False).eval()
-    for fast_param, plain_param in zip(fast_conv.state_dict().values(),
-                                       plain_conv.state_dict().values()):
-        plain_param.copy_(fast_param)
-
-    fast_mode_output = fast_conv(x)
+        3, 8, 2, norm_cfg=dict(type='BN'),
+        efficient_conv_bn_eval=False).eval()
+    for efficient_param, plain_param in zip(
+            efficient_conv.state_dict().values(),
+            plain_conv.state_dict().values()):
+        plain_param.copy_(efficient_param)
+
+    efficient_mode_output = efficient_conv(x)
     plain_mode_output = plain_conv(x)
-    assert torch.allclose(fast_mode_output, plain_mode_output, atol=1e-5)
+    assert torch.allclose(efficient_mode_output, plain_mode_output, atol=1e-5)
 
-    # `conv` attribute can be dynamically modified in fast mode
-    fast_conv = ConvModule(
-        3, 8, 2, norm_cfg=dict(type='BN'), fast_conv_bn_eval=True).eval()
+    # `conv` attribute can be dynamically modified in efficient mode
+    efficient_conv = ConvModule(
+        3, 8, 2, norm_cfg=dict(type='BN'), efficient_conv_bn_eval=True).eval()
     new_conv = nn.Conv2d(3, 8, 2).eval()
-    fast_conv.conv = new_conv
-    fast_mode_output = fast_conv(x)
-    plain_mode_output = fast_conv.activate(fast_conv.norm(new_conv(x)))
-    assert torch.allclose(fast_mode_output, plain_mode_output, atol=1e-5)
+    efficient_conv.conv = new_conv
+    efficient_mode_output = efficient_conv(x)
+    plain_mode_output = efficient_conv.activate(
+        efficient_conv.norm(new_conv(x)))
+    assert torch.allclose(efficient_mode_output, plain_mode_output, atol=1e-5)
 
     # conv + act
     conv = ConvModule(3, 8, 2)
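
The assertions use atol=1e-5 rather than exact equality because folding bn into the conv weights reorders the floating-point operations, so the fused and plain paths agree only up to rounding error. A quick hypothetical check of the discrepancy, reusing the tensors from the test above:

    max_diff = (efficient_mode_output - plain_mode_output).abs().max().item()
    print(f'max abs difference: {max_diff:.2e}')  # must stay below 1e-5 here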
