Commit c19fc43
mlp w/ glu style gating
thayeral committed Jan 9, 2025
1 parent f7085ab commit c19fc43
Showing 6 changed files with 44 additions and 11 deletions.
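
The commit replaces the plain two-layer MLP inside the transformer blocks with a GLU-style gated MLP (timm's SwiGLU, with nn.SiLU as the activation) and threads act_layer / mlp_layer arguments through the encoder, masked encoder, masked predictor, masked autoencoder, and ViT so the choice stays configurable. As a rough illustration of what "GLU-style gating" means, here is a minimal PyTorch sketch; the class and attribute names are made up for the example, and this is not the timm implementation:

import torch
import torch.nn as nn

class PlainMlp(nn.Module):
    # Standard transformer MLP: project up, apply the activation, project back down.
    def __init__(self, dim, hidden, act=nn.GELU):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden)
        self.act = act()
        self.fc2 = nn.Linear(hidden, dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

class GatedMlp(nn.Module):
    # GLU-style MLP: an activated branch is multiplied elementwise by a second,
    # purely linear projection of the same input. With act=nn.SiLU this is the
    # SwiGLU variant.
    def __init__(self, dim, hidden, act=nn.SiLU):
        super().__init__()
        self.fc_gate = nn.Linear(dim, hidden)   # branch that goes through the activation
        self.fc_value = nn.Linear(dim, hidden)  # linear branch used for gating
        self.act = act()
        self.fc_out = nn.Linear(hidden, dim)

    def forward(self, x):
        return self.fc_out(self.act(self.fc_gate(x)) * self.fc_value(x))

x = torch.randn(2, 16, 64)          # (batch, tokens, dim)
print(PlainMlp(64, 256)(x).shape)   # torch.Size([2, 16, 64])
print(GatedMlp(64, 256)(x).shape)   # torch.Size([2, 16, 64])
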
10 changes: 8 additions & 2 deletions src/encoder.py
@@ -5,7 +5,7 @@
import numpy as np
import torch
import torch.nn as nn
- from timm.layers import trunc_normal_, RmsNorm
+ from timm.layers import trunc_normal_, RmsNorm, SwiGLU

from transformer import Transformer

@@ -30,6 +30,8 @@ def __init__(
init_std=0.02,
fixed_dropout_depth=False,
norm_layer: nn.Module = RmsNorm,
+ act_layer: nn.Module = nn.SiLU,
+ mlp_layer: nn.Module = SwiGLU,
**kwargs,
):
super().__init__()
@@ -48,6 +50,8 @@ def __init__(
dpr = np.linspace(0, self.drop_path_rate, self.depth)

self.norm_layer = norm_layer
+ self.act_layer = act_layer
+ self.mlp_layer = mlp_layer

self.transformer_blocks = nn.ModuleList([
Transformer(
@@ -57,7 +61,9 @@ def __init__(
proj_drop=self.proj_drop_rate,
att_drop=self.att_drop_rate,
drop_path=self.drop_path_rate if fixed_dropout_depth and self.drop_path_rate > 0.0 else dpr[i],
- norm_layer=self.norm_layer
+ norm_layer=self.norm_layer,
+ act_layer=self.act_layer,
+ mlp_layer=self.mlp_layer,
)
for i in range(self.depth)
])
10 changes: 9 additions & 1 deletion src/maskedautoencoder.py
@@ -4,7 +4,7 @@

import torch
import torch.nn as nn
- from timm.layers import RmsNorm
+ from timm.layers import RmsNorm, SwiGLU

from maskedencoder import MaskedEncoder
from maskedpredictor import MaskedPredictor
@@ -114,6 +114,8 @@ def __init__(
init_std=0.02,
fixed_dropout_depth=False,
norm_layer: nn.Module = RmsNorm,
+ act_layer: nn.Module = nn.SiLU,
+ mlp_layer: nn.Module = SwiGLU,
use_conv_proj=False,
mask_ratio=.9,
window_mask_shape=None,
@@ -156,6 +158,8 @@ def __init__(

self.init_std = init_std
self.norm_layer = norm_layer
+ self.act_layer = act_layer
+ self.mlp_layer = mlp_layer

self.masked_encoder = MaskedEncoder(
input_fmt="BZYXC",
@@ -172,6 +176,8 @@ def __init__(
drop_path_rate=self.drop_path_rate,
fixed_dropout_depth=self.fixed_dropout_depth,
norm_layer=self.norm_layer,
+ act_layer=self.act_layer,
+ mlp_layer=self.mlp_layer,
init_std=self.init_std,
use_conv_proj=use_conv_proj,
cls_token=False,
@@ -194,6 +200,8 @@ def __init__(
drop_path_rate=self.drop_path_rate,
fixed_dropout_depth=self.fixed_dropout_depth,
norm_layer=self.norm_layer,
+ act_layer=self.act_layer,
+ mlp_layer=self.mlp_layer,
init_std=self.init_std,
cls_token=False,
)
8 changes: 7 additions & 1 deletion src/maskedencoder.py
@@ -4,7 +4,7 @@

import torch
import torch.nn as nn
- from timm.layers import RmsNorm
+ from timm.layers import RmsNorm, SwiGLU

from encoder import Encoder
from masking import apply_masks
@@ -90,6 +90,8 @@ def __init__(
init_std=0.02,
fixed_dropout_depth=False,
norm_layer: nn.Module = RmsNorm,
+ act_layer: nn.Module = nn.SiLU,
+ mlp_layer: nn.Module = SwiGLU,
use_conv_proj=False,
**kwargs,
):
@@ -122,6 +124,8 @@ def __init__(

self.init_std = init_std
self.norm_layer = norm_layer
+ self.act_layer = act_layer
+ self.mlp_layer = mlp_layer
self.norm = norm_layer(self.embed_dim) if norm_layer is not None else nn.Identity()

if use_conv_proj:
@@ -161,6 +165,8 @@ def __init__(
drop_path_rate=self.drop_path_rate,
fixed_dropout_depth=self.fixed_dropout_depth,
norm_layer=self.norm_layer,
+ act_layer=self.act_layer,
+ mlp_layer=self.mlp_layer,
init_std=self.init_std
)

8 changes: 7 additions & 1 deletion src/maskedpredictor.py
@@ -4,7 +4,7 @@

import torch
import torch.nn as nn
- from timm.layers import RmsNorm
+ from timm.layers import RmsNorm, SwiGLU

from encoder import Encoder
from patch_embeddings import PosEmbedding
@@ -91,6 +91,8 @@ def __init__(
init_std=0.02,
fixed_dropout_depth=False,
norm_layer: nn.Module = RmsNorm,
+ act_layer: nn.Module = nn.SiLU,
+ mlp_layer: nn.Module = SwiGLU,
**kwargs,
):
super().__init__()
@@ -124,6 +126,8 @@ def __init__(

self.init_std = init_std
self.norm_layer = norm_layer
+ self.act_layer = act_layer
+ self.mlp_layer = mlp_layer
self.norm = norm_layer(self.embed_dim) if norm_layer is not None else nn.Identity()

self.token_param = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
@@ -157,6 +161,8 @@ def __init__(
drop_path_rate=self.drop_path_rate,
fixed_dropout_depth=self.fixed_dropout_depth,
norm_layer=self.norm_layer,
+ act_layer=self.act_layer,
+ mlp_layer=self.mlp_layer,
init_std=self.init_std
)

9 changes: 5 additions & 4 deletions src/transformer.py
@@ -5,7 +5,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
- from timm.layers import Mlp, DropPath
+ from timm.layers import SwiGLU, DropPath

logging.basicConfig(
stream=sys.stdout,
@@ -84,6 +84,8 @@ def __init__(
att_drop: float = 0.,
drop_path: float = 0.,
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
+ act_layer: nn.Module = nn.SiLU,
+ mlp_layer: nn.Module = SwiGLU,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
@@ -99,12 +101,11 @@ def __init__(
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

self.norm2 = norm_layer(dim)
- self.mlp = Mlp(
+ self.mlp = mlp_layer(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
drop=proj_drop,
- act_layer=nn.GELU,
- use_conv=False,
+ act_layer=act_layer,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

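
For context on the swap above: timm's Mlp and the SwiGLU layer used here share the in_features / hidden_features / drop / act_layer constructor arguments and both map (batch, tokens, dim) back to the same shape, which is what makes mlp_layer a drop-in argument. The gated version does carry an extra input projection, so at the same hidden width it has roughly 1.5x the parameters. A quick comparison, assuming a recent timm release where SwiGLU accepts these keyword arguments:

import torch
import torch.nn as nn
from timm.layers import Mlp, SwiGLU

dim, mlp_ratio = 768, 4.0
hidden = int(dim * mlp_ratio)   # same expression as in the block above

mlp = Mlp(in_features=dim, hidden_features=hidden, act_layer=nn.GELU, drop=0.0)
glu = SwiGLU(in_features=dim, hidden_features=hidden, act_layer=nn.SiLU, drop=0.0)

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print(n_params(mlp))  # roughly 2 * dim * hidden (plus biases): two projections
print(n_params(glu))  # roughly 3 * dim * hidden: the extra gate projection adds ~50%

x = torch.randn(2, 16, dim)
assert mlp(x).shape == glu(x).shape == x.shape  # interchangeable input/output shapes

Since the diff keeps hidden_features=int(dim * mlp_ratio), each block's MLP gets larger than before; a common convention from the GLU-variants literature, when parameter parity with the plain MLP is wanted, is to shrink the gated hidden width to about two thirds, e.g. int(dim * mlp_ratio * 2 / 3).
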
10 changes: 8 additions & 2 deletions src/vit.py
@@ -5,7 +5,7 @@

import torch
import torch.nn as nn
- from timm.layers import AttentionPoolLatent
+ from timm.layers import AttentionPoolLatent, Mlp
from timm.models.vision_transformer import global_pool_nlc

from encoder import Encoder
@@ -93,6 +93,8 @@ def __init__(
fixed_dropout_depth=False,
global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'avgmax',
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
+ act_layer = nn.GELU,
+ mlp_layer = Mlp,
use_conv_proj=False,
**kwargs,
):
@@ -126,6 +128,8 @@ def __init__(
self.init_std = init_std
self.global_pool = global_pool
self.norm_layer = norm_layer
+ self.act_layer = act_layer
+ self.mlp_layer = mlp_layer
self.norm = norm_layer(self.embed_dim) if norm_layer is not None else nn.Identity()

if use_conv_proj:
@@ -165,7 +169,9 @@ def __init__(
drop_path_rate=self.drop_path_rate,
fixed_dropout_depth=self.fixed_dropout_depth,
norm_layer=self.norm_layer,
- init_std=self.init_std
+ act_layer=self.act_layer,
+ mlp_layer=self.mlp_layer,
+ init_std=self.init_std,
)

self.global_pool = global_pool
