A few more changes to the ResNet-RS architecture.
Added a squeeze-and-excite (SE) layer.
Added dropout after global average pooling and before the fc layer.
Added a conv-BN-activation block in place of the max-pooling layer in the stem.
Added an API to create a model, create a pretrained model, and list available models (see the usage sketch below).
Added utility functions to load and save weights.
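A minimal usage sketch of the new API, assuming the package layout shown in the diffs below (model/base.py and model/resnetrs.py on the import path); argument values are illustrative and the pretrained-weight URLs are still empty in this commit:

    from base import Bottleneck
    from resnetrs import ResnetRS

    # List the architectures the pretrained API knows about.
    print(ResnetRS.list_pretrained())   # ['resnetrs50', 'resnetrs101', 'resnetrs152', 'resnetrs200']

    # Build a ResNet-RS-50-style model from scratch, with SE blocks enabled and
    # dropout applied after global average pooling, before the fc layer.
    model = ResnetRS.create_model(Bottleneck, [3, 4, 6, 3],
                                  num_classes=10, dropout_ratio=0.25)

    # Pretrained creation follows the same pattern once weight URLs are populated:
    # model = ResnetRS.create_pretrained('resnetrs50', in_ch=3, num_classes=1000)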
nachiket273 committed Jun 30, 2021
1 parent 612af0a commit 528817f
Showing 3 changed files with 266 additions and 20 deletions.
66 changes: 55 additions & 11 deletions model/base.py
@@ -25,7 +25,12 @@ def __init__(self, in_ch=3, stem_width=32, norm_layer=nn.BatchNorm2d,
])
self.bn1 = norm_layer(inplanes)
self.actn1 = actn(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.maxpool = nn.Sequential(*[
nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=2, padding=1,
bias=False),
norm_layer(inplanes),
actn(inplace=True)
])
self.init_weights()

def init_weights(self):
@@ -49,7 +54,8 @@ class BasicBlock(nn.Module):
expansion = 1

def __init__(self, inplanes, planes, stride=1, norm_layer=nn.BatchNorm2d,
actn=nn.ReLU, downsample=None, zero_init_last_bn=True):
actn=nn.ReLU, downsample=None, seblock=True,
reduction_ratio=0.25, zero_init_last_bn=True):
super().__init__()
outplanes = planes * self.expansion
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
@@ -60,9 +66,14 @@ def __init__(self, inplanes, planes, stride=1, norm_layer=nn.BatchNorm2d,
self.conv2 = nn.Conv2d(planes, outplanes, kernel_size=3, stride=1,
padding=1, bias=False)
self.bn2 = norm_layer(outplanes)
self.seblock = seblock
if seblock:
self.se = SEBlock(outplanes, reduction_ratio)
self.actn2 = actn(inplace=True)
self.downsample = downsample if downsample is not None \
else nn.Identity()
self.down = False
if downsample is not None:
self.downsample = downsample
self.down = True
self.init_weights(zero_init_last_bn)

def init_weights(self, zero_init_last_bn=True):
@@ -78,14 +89,16 @@ def init_weights(self, zero_init_last_bn=True):
nn.init.zeros_(self.bn2.weight)

def forward(self, x):
shortcut = self.downsample(x)
shortcut = self.downsample(x) if self.down else x

x = self.conv1(x)
x = self.bn1(x)
x = self.actn1(x)

x = self.conv2(x)
x = self.bn2(x)
if self.seblock:
x = self.se(x)
x += shortcut
x = self.actn2(x)

@@ -96,7 +109,8 @@ class Bottleneck(nn.Module):
expansion = 4

def __init__(self, inplanes, planes, stride=1, norm_layer=nn.BatchNorm2d,
actn=nn.ReLU, downsample=None, zero_init_last_bn=True):
actn=nn.ReLU, downsample=None, seblock=True,
reduction_ratio=0.25, zero_init_last_bn=True):
super().__init__()
outplanes = planes * self.expansion
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
@@ -110,9 +124,14 @@ def __init__(self, inplanes, planes, stride=1, norm_layer=nn.BatchNorm2d,

self.conv3 = nn.Conv2d(planes, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
self.seblock = seblock
if seblock:
self.se = SEBlock(outplanes, reduction_ratio)
self.actn3 = actn(inplace=True)
self.downsample = downsample if downsample is not None \
else nn.Identity()
self.down = False
if downsample is not None:
self.downsample = downsample
self.down = True
self.init_weights(zero_init_last_bn)

def init_weights(self, zero_init_last_bn=True):
@@ -128,7 +147,7 @@ def init_weights(self, zero_init_last_bn=True):
nn.init.zeros_(self.bn2.weight)

def forward(self, x):
shortcut = self.downsample(x)
shortcut = self.downsample(x) if self.down else x

x = self.conv1(x)
x = self.bn1(x)
@@ -140,6 +159,8 @@ def forward(self, x):

x = self.conv3(x)
x = self.bn3(x)
if self.seblock:
x = self.se(x)
x += shortcut
x = self.actn3(x)

@@ -150,9 +171,13 @@ class Downsample(nn.Module):
def __init__(self, in_ch, out_ch, kernel_size=1, stride=1,
norm_layer=nn.BatchNorm2d):
super().__init__()
if stride == 1:
avgpool = nn.Identity()
else:
avgpool = nn.AvgPool2d(2, stride=stride, ceil_mode=True,
count_include_pad=False)
self.downsample = nn.Sequential(*[
nn.AvgPool2d(2, stride=stride, ceil_mode=True,
count_include_pad=False),
avgpool,
nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=1,
padding=0, bias=False),
norm_layer(out_ch)
@@ -170,3 +195,22 @@ def init_weights(self):

def forward(self, x):
return self.downsample(x)


class SEBlock(nn.Module):
def __init__(self, channels, reduction_ratio=0.25):
super().__init__()
reduced_channels = int(channels * reduction_ratio)
self.conv1 = nn.Conv2d(channels, reduced_channels, kernel_size=1)
self.actn = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(reduced_channels, channels, kernel_size=1)
self.sigmoid = nn.Sigmoid()

def forward(self, x):
orig = x
x = x.mean((2, 3), keepdim=True)
x = self.conv1(x)
x = self.actn(x)
x = self.conv2(x)
x = self.sigmoid(x)
return orig * x
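For orientation, a small sketch of what the new SEBlock does at runtime: channel statistics from global average pooling pass through a two-conv bottleneck and a sigmoid, and the resulting per-channel gate rescales the block's output before the residual addition. The shapes below are illustrative only:

    import torch
    from base import SEBlock

    x = torch.randn(2, 256, 14, 14)          # (batch, channels, height, width)
    se = SEBlock(256, reduction_ratio=0.25)  # squeeze 256 -> 64 channels, then expand back
    out = se(x)                              # per-channel gate in (0, 1), broadcast over H and W
    assert out.shape == x.shape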
143 changes: 134 additions & 9 deletions model/resnetrs.py
@@ -5,31 +5,64 @@
(https://arxiv.org/pdf/2103.07579.pdf)
"""
from functools import partial
import torch.nn as nn
import torch.nn.functional as F

from base import StemBlock, BasicBlock, Bottleneck, Downsample
from util import get_pretrained_weights


class ResnetRS(nn.Module):
PRETRAINED_MODELS = [
'resnetrs50',
'resnetrs101',
'resnetrs152',
'resnetrs200'
]

PRETRAINED_URLS = {
'resnetrs50': '',
'resnetrs101': '',
'resnetrs152': '',
'resnetrs200': '',
}

DEFAULT_CFG = {
'in_ch': 3,
'num_classes': 1000,
'stem_width': 64,
'down_kernel_size': 1,
'actn': partial(nn.ReLU, inplace=True),
'norm_layer': nn.BatchNorm2d,
'zero_init_last_bn': True,
'seblock': True,
'reduction_ratio': 0.25,
'dropout_ratio': 0.,
'conv1': 'conv1',
'classifier': 'fc'
}


class Resnet(nn.Module):
def __init__(self, block, layers, num_classes=1000, in_ch=3, stem_width=64,
down_kernel_size=1, actn=nn.ReLU, norm_layer=nn.BatchNorm2d,
seblock=True, reduction_ratio=0.25, dropout_ratio=0.,
zero_init_last_bn=True):
super().__init__()
self.num_classes = num_classes
self.norm_layer = norm_layer
self.actn = actn
self.dropout_ratio = float(dropout_ratio)
self.zero_init_last_bn = zero_init_last_bn
self.conv1 = StemBlock(in_ch, stem_width, norm_layer, actn)
channels = [64, 128, 256, 512]
self.make_layers(block, layers, channels, stem_width*2,
down_kernel_size)
self.avg_pool = nn.Sequential(*[
nn.AdaptiveAvgPool2d(1),
nn.Flatten()
])
down_kernel_size, seblock, reduction_ratio)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(512 * block.expansion, num_classes, bias=True)

def make_layers(self, block, nlayers, channels, inplanes, kernel_size=1):
def make_layers(self, block, nlayers, channels, inplanes, kernel_size=1,
seblock=True, reduction_ratio=0.25):
for idx, (nlayer, channel) in enumerate(zip(nlayers, channels)):
name = "layer" + str(idx+1)
stride = 1 if idx == 0 else 2
@@ -44,19 +77,111 @@ def make_layers(self, block, nlayers, channels, inplanes, kernel_size=1):
downsample = downsample if layer_idx == 0 else None
stride = stride if layer_idx == 0 else 1
blocks.append(block(inplanes, channel, stride, self.norm_layer,
self.actn, downsample,
self.zero_init_last_bn))
self.actn, downsample, seblock,
reduction_ratio, self.zero_init_last_bn))

inplanes = channel * block.expansion

self.add_module(*(name, nn.Sequential(*blocks)))

def init_weights(self, zero_init_last_bn=True):
for _, module in self.named_modules():
if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode='fan_out',
nonlinearity='relu')
if isinstance(module, nn.BatchNorm2d):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.actn1(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avg_pool(x)
x = x.flatten(1, -1)
if self.dropout_ratio > 0.:
x = F.dropout(x, p=self.dropout_ratio, training=self.training)
x = self.fc(x)
return x


class ResnetRS():
def __init__(self):
super().__init__()

@classmethod
def create_model(cls, block, layers, num_classes=1000, in_ch=3,
stem_width=64, down_kernel_size=1,
actn=partial(nn.ReLU, inplace=True),
norm_layer=nn.BatchNorm2d, seblock=True,
reduction_ratio=0.25, dropout_ratio=0.,
zero_init_last_bn=True):

return Resnet(block, layers, num_classes=num_classes, in_ch=in_ch,
stem_width=stem_width, down_kernel_size=down_kernel_size,
actn=actn, norm_layer=norm_layer, seblock=seblock,
reduction_ratio=reduction_ratio,
dropout_ratio=dropout_ratio,
zero_init_last_bn=zero_init_last_bn)

@classmethod
def list_pretrained(cls):
return PRETRAINED_MODELS

@classmethod
def _is_valid_model_name(cls, name):
name = name.strip()
name = name.lower()
return name in PRETRAINED_MODELS

@classmethod
def _get_url(cls, name):
return PRETRAINED_URLS[name]

@classmethod
def _get_default_cfg(cls):
return DEFAULT_CFG

@classmethod
def _get_cfg(cls, name):
cfg = ResnetRS._get_default_cfg()
cfg['block'] = Bottleneck
if name == 'resnetrs50':
cfg['layers'] = [3, 4, 6, 3]
elif name == 'resnetrs101':
cfg['layers'] = [3, 4, 23, 3]
elif name == 'resnetrs152':
cfg['layers'] = [3, 8, 36, 3]
elif name == 'resnetrs200':
cfg['layers'] = [3, 24, 36, 3]
return cfg

@classmethod
def create_pretrained(cls, name, in_ch=0, num_classes=0):
if not ResnetRS._is_valid_model_name(name):
raise ValueError('Available pretrained models: ' +
', '.join(PRETRAINED_MODELS))

cfg = ResnetRS._get_cfg(name)
in_ch = cfg['in_ch'] if in_ch == 0 else in_ch
num_classes = cfg['num_classes'] if num_classes == 0 else num_classes

url = ResnetRS._get_url(name)
model = Resnet(cfg['block'], cfg['layers'], num_classes=num_classes,
in_ch=in_ch, stem_width=cfg['stem_width'],
down_kernel_size=cfg['down_kernel_size'],
actn=cfg['actn'], norm_layer=cfg['norm_layer'],
seblock=cfg['seblock'],
reduction_ratio=cfg['reduction_ratio'],
zero_init_last_bn=cfg['zero_init_last_bn'])

state_dict = get_pretrained_weights(url, cfg, num_classes, in_ch,
check_hash=True)

model.load_state_dict(state_dict, strict=cfg['strict'])
return model
(Diff for the third changed file did not load.)
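The commit message mentions utility functions to load and save weights, and resnetrs.py imports get_pretrained_weights from util, but that file's diff is not visible above. What follows is only an assumed sketch of what such helpers commonly look like in PyTorch, matching the call signature used in create_pretrained; it is not the actual contents of this commit:

    # Assumed sketch only; the real util module added by this commit is not shown here.
    import torch
    from torch.hub import load_state_dict_from_url

    def save_weights(model, path):
        # Persist parameters and buffers only, not the module object itself.
        torch.save(model.state_dict(), path)

    def load_weights(model, path, strict=True):
        state_dict = torch.load(path, map_location='cpu')
        model.load_state_dict(state_dict, strict=strict)
        return model

    def get_pretrained_weights(url, cfg, num_classes, in_ch, check_hash=True):
        # Fetch the checkpoint, then drop layers whose shapes no longer match
        # the requested input channels or class count so they re-initialize.
        state_dict = load_state_dict_from_url(url, map_location='cpu',
                                              check_hash=check_hash)
        if in_ch != cfg['in_ch']:
            state_dict = {k: v for k, v in state_dict.items()
                          if not k.startswith(cfg['conv1'])}
        if num_classes != cfg['num_classes']:
            state_dict = {k: v for k, v in state_dict.items()
                          if not k.startswith(cfg['classifier'])}
        return state_dict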
