Skip to content

Commit

Permalink
Update ConvNeXt
Browse files Browse the repository at this point in the history
  • Loading branch information
ericup committed Aug 20, 2023
1 parent 8279845 commit a838810
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions celldetection/models/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,26 @@ def __init__(
class CNBlock(nn.Module): # ported from torchvision.models.convnext to support n-dimensions and add more features
def __init__(self, in_channels, out_channels=None, layer_scale: float = 1e-6, stochastic_depth_prob: float = 0,
norm_layer: Optional[Callable[..., nn.Module]] = None, activation='gelu', stride: int = 1,
identity_norm_layer=None, nd: int = 2) -> None:
identity_norm_layer=None, nd: int = 2, conv_kwargs=None) -> None:
super().__init__()
if norm_layer is None:
norm_layer = partial(nn.LayerNorm, eps=1e-6)
if conv_kwargs is None:
conv_kwargs = {}
Conv = lookup_nn('Conv2d', nd=nd, call=False)
out_channels = in_channels if out_channels is None else out_channels
self.identity = None
if in_channels != out_channels or stride != 1:
if identity_norm_layer is None:
identity_norm_layer = [LayerNorm1d, LayerNorm2d, LayerNorm3d][nd - 1]
self.identity = nn.Sequential( # following option (b) in He et al. (2015)
Conv(out_channels, out_channels, kernel_size=1, stride=stride, bias=False),
Conv(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
identity_norm_layer(out_channels)
)
self.block = nn.Sequential(
Conv(in_channels, out_channels, kernel_size=7, padding=3, groups=out_channels, bias=True),
Conv(in_channels, out_channels, kernel_size=conv_kwargs.pop('kernel_size', 7),
padding=conv_kwargs.pop('padding', 3), groups=conv_kwargs.pop('groups', out_channels),
bias=conv_kwargs.pop('bias', True), **conv_kwargs),
Permute(list(channels_last_permute(nd))),
norm_layer(out_channels),
nn.Linear(in_features=out_channels, out_features=4 * out_channels, bias=True),
Expand Down

0 comments on commit a838810

Please sign in to comment.