-
Notifications
You must be signed in to change notification settings - Fork 0
/
mlpmixer.py
144 lines (123 loc) · 4.94 KB
/
mlpmixer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from torch import nn
import torch
import math
import matplotlib.pyplot as plt
import os
class PatchEmbed(nn.Module):
    """Embed a square image as a sequence of non-overlapping patch tokens.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every patch to an ``embed_dim``-sized vector.

    :param img_size: side length of the (square) input image
    :param patch_size: side length of each (square) patch
    :param in_chans: number of input image channels
    :param embed_dim: size of the embedding produced for each patch
    """

    def __init__(self, img_size, patch_size, in_chans=3,
                 embed_dim=768):
        super().__init__()
        patches_per_side = img_size // patch_size

        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = patches_per_side
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """
        :param x: image tensor of shape [batch, channels, img_size, img_size]
        :return out: [batch, num_patches, embed_dim]
        """
        _batch, _chans, height, width = x.shape
        expected = self.img_size
        assert height == expected, f"Input image height ({height}) doesn't match model ({self.img_size})."
        assert width == expected, f"Input image width ({width}) doesn't match model ({self.img_size})."

        # Conv output is [batch, embed_dim, grid, grid]; merge the two grid
        # axes into one token axis and move it in front of the channels.
        return self.proj(x).flatten(start_dim=2).transpose(1, 2)
class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> dropout -> Linear -> dropout.

    The output width always equals ``in_features``, so the module can sit on
    a residual branch.

    :param in_features: size of the last input dimension (and of the output)
    :param hidden_features: width of the hidden layer
    :param act_layer: activation module class, instantiated with no arguments
    :param drop: dropout probability used after each linear layer
    """

    def __init__(
            self,
            in_features,
            hidden_features,
            act_layer=nn.GELU,
            drop=0.,
    ):
        super().__init__()
        # Expand to the hidden width ...
        self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        # ... then project back to the input width.
        self.fc2 = nn.Linear(hidden_features, in_features, bias=True)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through the layers in order; the last dim is preserved."""
        pipeline = (self.fc1, self.act, self.drop1, self.fc2, self.drop2)
        for layer in pipeline:
            x = layer(x)
        return x
class MixerBlock(nn.Module):
    """ Residual Block w/ token mixing and channel MLPs
    Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601

    :param dim: channel (embedding) size of every token
    :param seq_len: number of tokens in the sequence
    :param mlp_ratio: ``(token_ratio, channel_ratio)``; the hidden widths of the
        token and channel MLPs are ``int(ratio * dim)``
    :param activation: ``'gelu'`` or ``'relu'``; any other key raises ``KeyError``
    :param drop: dropout probability inside both MLPs
    :param drop_path: stochastic-depth probability, i.e. the chance that a
        residual branch is dropped for a given sample during training
    :raises ValueError: if ``drop_path`` is outside ``[0, 1]``
    """

    def __init__(
            self, dim, seq_len, mlp_ratio=(0.5, 4.0),
            activation='gelu', drop=0., drop_path=0.):
        super(MixerBlock, self).__init__()
        if not 0.0 <= drop_path <= 1.0:
            raise ValueError(f"drop_path must be in [0, 1], got {drop_path}")
        act_layer = {'gelu': nn.GELU, 'relu': nn.ReLU}[activation]
        tokens_dim, channels_dim = int(mlp_ratio[0] * dim), int(mlp_ratio[1] * dim)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)  # norm1 used with mlp_tokens
        self.mlp_tokens = Mlp(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)  # norm2 used with mlp_channels
        self.mlp_channels = Mlp(dim, channels_dim, act_layer=act_layer, drop=drop)
        # Plain float, not a parameter/buffer, so the state_dict is unchanged.
        self.drop_path_prob = float(drop_path)

    def _drop_path(self, branch):
        """Apply per-sample stochastic depth to a residual branch."""
        # Identity when disabled or at inference time, so the default
        # configuration behaves exactly like a plain residual block.
        if self.drop_path_prob == 0.0 or not self.training:
            return branch
        keep_prob = 1.0 - self.drop_path_prob
        # One Bernoulli draw per sample, broadcast over the remaining dims.
        mask_shape = (branch.shape[0],) + (1,) * (branch.dim() - 1)
        mask = branch.new_empty(mask_shape).bernoulli_(keep_prob)
        if keep_prob > 0.0:
            # Rescale kept samples so the expected value of the branch is unchanged.
            mask.div_(keep_prob)
        return branch * mask

    def forward(self, x):
        """
        :param x: token tensor of shape [batch, seq_len, dim]
        :return out: tensor with the same shape as ``x``
        """
        # Token mixing: transpose so the MLP acts along the token axis.
        out = self.norm1(x)
        out = torch.transpose(out, 1, 2)  # [batch, dim, seq_len]
        out = self.mlp_tokens(out)  # [batch, dim, seq_len]
        token_mix_out = x + self._drop_path(torch.transpose(out, 1, 2))
        # Channel mixing: the MLP acts along the embedding axis.
        out = self.norm2(token_mix_out)
        out = token_mix_out + self._drop_path(self.mlp_channels(out))
        return out
class MLPMixer(nn.Module):
    """MLP-Mixer image classifier: patch embedding, mixer blocks, linear head.

    :param num_classes: number of output logits
    :param img_size: side length of the (square) input images
    :param patch_size: side length of each (square) patch
    :param embed_dim: embedding size of every patch token
    :param num_blocks: number of stacked mixer blocks
    :param drop_rate: dropout probability forwarded to every mixer block
    :param activation: activation name forwarded to every mixer block
    :param in_chans: number of input image channels (defaults to 3)
    """

    def __init__(self, num_classes, img_size, patch_size, embed_dim, num_blocks,
                 drop_rate=0., activation='gelu', in_chans=3):
        super(MLPMixer, self).__init__()
        self.patchemb = PatchEmbed(img_size=img_size,
                                   patch_size=patch_size,
                                   in_chans=in_chans,
                                   embed_dim=embed_dim)
        self.blocks = nn.Sequential(*[
            MixerBlock(
                dim=embed_dim, seq_len=self.patchemb.num_patches,
                activation=activation, drop=drop_rate)
            for _ in range(num_blocks)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        self.num_classes = num_classes
        # Re-initialise every submodule with the scheme in init_weights.
        self.apply(self.init_weights)

    def init_weights(self, module):
        """Initialise one submodule; intended as the callback for ``self.apply``.

        Linear layers get Xavier-uniform weights, Conv2d layers get
        Kaiming-normal weights, LayerNorm weights are set to one; all biases
        are zeroed. Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
        elif isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight)
        elif isinstance(module, nn.LayerNorm):
            # weight is None when the norm has no affine parameters.
            if module.weight is not None:
                nn.init.ones_(module.weight)
        else:
            return
        # bias is None for layers built with bias=False (or non-affine norms).
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, images):
        """ MLPMixer forward
        :param images: [batch, in_chans, img_size, img_size]
        :return out: class logits of shape [batch, num_classes]
        """
        out = self.patchemb(images)  # step 1: patch embedding
        out = self.blocks(out)       # step 2: mixer blocks
        out = self.norm(out)         # step 3: layer norm
        out = out.mean(dim=1)        # step 4: average over the token axis
        out = self.head(out)         # step 5: classification head
        return out

    def visualize(self, logdir):
        """ Visualize the token mixer layer
        in the desired directory

        Not implemented yet: always raises ``NotImplementedError``.
        """
        raise NotImplementedError
if __name__ == "__main__":
    # Smoke check: build a small 10-class mixer for 32x32 images
    # (4x4 patches, 256-dim tokens, 4 blocks) and print its layer summary.
    mixer = MLPMixer(
        num_classes=10,
        img_size=32,
        patch_size=4,
        embed_dim=256,
        num_blocks=4,
    )
    print(mixer)