diff --git a/chainercv/links/__init__.py b/chainercv/links/__init__.py
index 17eddff0ac..8161c76d29 100644
--- a/chainercv/links/__init__.py
+++ b/chainercv/links/__init__.py
@@ -1,3 +1,4 @@
+from chainercv.links.connection.bn_activ_conv_2d import BNActivConv2D  # NOQA
 from chainercv.links.connection.conv_2d_activ import Conv2DActiv  # NOQA
 from chainercv.links.connection.conv_2d_bn_activ import Conv2DBNActiv  # NOQA
diff --git a/chainercv/links/connection/__init__.py b/chainercv/links/connection/__init__.py
index a2b366ee59..e3b8c6a838 100644
--- a/chainercv/links/connection/__init__.py
+++ b/chainercv/links/connection/__init__.py
@@ -1,2 +1,3 @@
+from chainercv.links.connection.bn_activ_conv_2d import BNActivConv2D  # NOQA
 from chainercv.links.connection.conv_2d_activ import Conv2DActiv  # NOQA
 from chainercv.links.connection.conv_2d_bn_activ import Conv2DBNActiv  # NOQA
diff --git a/chainercv/links/connection/bn_activ_conv_2d.py b/chainercv/links/connection/bn_activ_conv_2d.py
new file mode 100644
index 0000000000..ef895a0862
--- /dev/null
+++ b/chainercv/links/connection/bn_activ_conv_2d.py
@@ -0,0 +1,65 @@
+import chainer
+from chainer.functions import relu
+from chainer.links import BatchNormalization
+from chainer.links import Convolution2D
+
+
+class BNActivConv2D(chainer.Chain):
+    """Batch Normalization --> Activation --> Convolution2D
+
+    This is a chain that sequentially applies a batch normalization,
+    an activation and a two-dimensional convolution.
+
+    The arguments are the same as those of
+    :class:`chainer.links.Convolution2D`
+    except for :obj:`activ` and :obj:`bn_kwargs`.
+    Note that the default value for :obj:`nobias`
+    is changed to :obj:`True`.
+
+    Unlike :class:`chainer.links.Convolution2D`, this class requires
+    :obj:`in_channels` to be defined explicitly.
+
+    >>> l = BNActivConv2D(5, 10, 3)
+
+    Args:
+        in_channels (int): The number of channels of input arrays.
+            This needs to be explicitly defined.
+        out_channels (int): The number of channels of output arrays.
+        ksize (int or pair of ints): Size of filters (a.k.a. kernels).
+            :obj:`ksize=k` and :obj:`ksize=(k, k)` are equivalent.
+        stride (int or pair of ints): Stride of filter applications.
+            :obj:`stride=s` and :obj:`stride=(s, s)` are equivalent.
+        pad (int or pair of ints): Spatial padding width for input arrays.
+            :obj:`pad=p` and :obj:`pad=(p, p)` are equivalent.
+        nobias (bool): If :obj:`True`,
+            then this link does not use the bias term.
+        initialW (4-D array): Initial weight value. If :obj:`None`, the default
+            initializer is used.
+            May also be a callable that takes :obj:`numpy.ndarray` or
+            :obj:`cupy.ndarray` and edits its value.
+        initial_bias (1-D array): Initial bias value. If :obj:`None`, the bias
+            is set to 0.
+            May also be a callable that takes :obj:`numpy.ndarray` or
+            :obj:`cupy.ndarray` and edits its value.
+        activ (callable): An activation function. The default value is
+            :func:`chainer.functions.relu`.
+        bn_kwargs (dict): Keyword arguments passed to initialize
+            :class:`chainer.links.BatchNormalization`.
+
+    """
+
+    def __init__(self, in_channels, out_channels, ksize=None,
+                 stride=1, pad=0, nobias=True, initialW=None,
+                 initial_bias=None, activ=relu, bn_kwargs=dict()):
+        self.activ = activ
+        super(BNActivConv2D, self).__init__()
+        with self.init_scope():
+            self.bn = BatchNormalization(in_channels, **bn_kwargs)
+            self.conv = Convolution2D(
+                in_channels, out_channels, ksize, stride, pad,
+                nobias, initialW, initial_bias)
+
+    def __call__(self, x):
+        h = self.bn(x)
+        h = self.activ(h)
+        return self.conv(h)
diff --git a/docs/source/reference/links/connection.rst b/docs/source/reference/links/connection.rst
index f662e8855a..7f7d304ea1 100644
--- a/docs/source/reference/links/connection.rst
+++ b/docs/source/reference/links/connection.rst
@@ -4,6 +4,10 @@ Connection
 
 .. module:: chainercv.links.connection
 
+BNActivConv2D
+-------------
+.. autoclass:: BNActivConv2D
+
 Conv2DActiv
 -----------
 .. autoclass:: Conv2DActiv
diff --git a/tests/links_tests/connection_tests/test_bn_activ_conv_2d.py b/tests/links_tests/connection_tests/test_bn_activ_conv_2d.py
new file mode 100644
index 0000000000..398e6fab76
--- /dev/null
+++ b/tests/links_tests/connection_tests/test_bn_activ_conv_2d.py
@@ -0,0 +1,94 @@
+import unittest
+
+import numpy as np
+
+import chainer
+from chainer import cuda
+from chainer.functions import relu
+from chainer import testing
+from chainer.testing import attr
+
+from chainercv.links import BNActivConv2D
+
+
+def _add_one(x):
+    return x + 1
+
+
+@testing.parameterize(*testing.product({
+    'activ': ['relu', 'add_one'],
+}))
+class TestBNActivConv2D(unittest.TestCase):
+
+    in_channels = 1
+    out_channels = 1
+    ksize = 3
+    stride = 1
+    pad = 1
+
+    def setUp(self):
+        if self.activ == 'relu':
+            activ = relu
+        elif self.activ == 'add_one':
+            activ = _add_one
+        self.x = np.random.uniform(
+            -1, 1, (5, self.in_channels, 5, 5)).astype(np.float32)
+        self.gy = np.random.uniform(
+            -1, 1, (5, self.out_channels, 5, 5)).astype(np.float32)
+
+        # Convolution is the identity function.
+        initialW = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]],
+                            dtype=np.float32).reshape((1, 1, 3, 3))
+        bn_kwargs = {'decay': 0.8}
+        initial_bias = 0
+        self.l = BNActivConv2D(
+            self.in_channels, self.out_channels, self.ksize,
+            self.stride, self.pad,
+            initialW=initialW, initial_bias=initial_bias,
+            activ=activ, bn_kwargs=bn_kwargs)
+
+    def check_forward(self, x_data):
+        x = chainer.Variable(x_data)
+        # Make the batch normalization the identity function.
+        self.l.bn.avg_var[:] = 1
+        self.l.bn.avg_mean[:] = 0
+        with chainer.using_config('train', False):
+            y = self.l(x)
+        self.assertIsInstance(y, chainer.Variable)
+        self.assertIsInstance(y.data, self.l.xp.ndarray)
+
+        if self.activ == 'relu':
+            np.testing.assert_almost_equal(
+                cuda.to_cpu(y.data), np.maximum(cuda.to_cpu(x_data), 0),
+                decimal=4
+            )
+        elif self.activ == 'add_one':
+            np.testing.assert_almost_equal(
+                cuda.to_cpu(y.data), cuda.to_cpu(x_data) + 1,
+                decimal=4
+            )
+
+    def test_forward_cpu(self):
+        self.check_forward(self.x)
+
+    @attr.gpu
+    def test_forward_gpu(self):
+        self.l.to_gpu()
+        self.check_forward(cuda.to_gpu(self.x))
+
+    def check_backward(self, x_data, y_grad):
+        x = chainer.Variable(x_data)
+        y = self.l(x)
+        y.grad = y_grad
+        y.backward()
+
+    def test_backward_cpu(self):
+        self.check_backward(self.x, self.gy)
+
+    @attr.gpu
+    def test_backward_gpu(self):
+        self.l.to_gpu()
+        self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy))
+
+
+testing.run_module(__name__, __file__)