Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding layernorm and no weight decay for AB #96

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions mambular/arch_utils/mamba_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def __init__(
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
layer_norm_eps=1e-05,
AB_weight_decay=False,
AB_layer_norm=True,
):
super().__init__()

Expand All @@ -66,6 +69,9 @@ def __init__(
activation,
bidirectional,
use_learnable_interaction,
layer_norm_eps,
AB_weight_decay,
AB_layer_norm,
)
for _ in range(n_layers)
]
Expand Down Expand Up @@ -105,6 +111,9 @@ def __init__(
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
layer_norm_eps=1e-05,
AB_weight_decay=False,
AB_layer_norm=False,
):
super().__init__()

Expand Down Expand Up @@ -149,8 +158,11 @@ def __init__(
activation=activation,
bidirectional=bidirectional,
use_learnable_interaction=use_learnable_interaction,
layer_norm_eps=layer_norm_eps,
AB_weight_decay=AB_weight_decay,
AB_layer_norm=AB_layer_norm,
)
self.norm = norm(d_model)
self.norm = norm(d_model, eps=layer_norm_eps)

def forward(self, x):
output = self.layers(self.norm(x)) + x
Expand Down Expand Up @@ -189,6 +201,9 @@ def __init__(
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
layer_norm_eps=1e-05,
AB_weight_decay=False,
AB_layer_norm=False,
):
super().__init__()
self.d_inner = d_model * expand_factor
Expand Down Expand Up @@ -239,6 +254,7 @@ def __init__(
elif dt_init == "random":
nn.init.uniform_(self.dt_proj_fwd.weight, -dt_init_std, dt_init_std)
if self.bidirectional:

nn.init.uniform_(self.dt_proj_bwd.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
Expand All @@ -262,17 +278,35 @@ def __init__(

A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
self.A_log_fwd = nn.Parameter(torch.log(A))
self.D_fwd = nn.Parameter(torch.ones(self.d_inner))

if self.bidirectional:
self.A_log_bwd = nn.Parameter(torch.log(A))
self.D_bwd = nn.Parameter(torch.ones(self.d_inner))

if not AB_weight_decay:
self.A_log_fwd._no_weight_decay = True
self.D_fwd._no_weight_decay = True

self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
if self.bidirectional:
self.D_bwd = nn.Parameter(torch.ones(self.d_inner))

if not AB_weight_decay:
self.A_log_bwd._no_weight_decay = True
self.D_bwd._no_weight_decay = True

self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
self.dt_rank = dt_rank
self.d_state = d_state

if AB_layer_norm:
self.dt_layernorm = RMSNorm(self.dt_rank, eps=layer_norm_eps)
self.B_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
self.C_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
else:
self.dt_layernorm = None
self.B_layernorm = None
self.C_layernorm = None

def forward(self, x):
_, L, _ = x.shape

Expand Down Expand Up @@ -316,6 +350,15 @@ def forward(self, x):

return output

def _apply_layernorms(self, dt, B, C):
if self.dt_layernorm is not None:
dt = self.dt_layernorm(dt)
if self.B_layernorm is not None:
B = self.B_layernorm(B)
if self.C_layernorm is not None:
C = self.C_layernorm(C)
return dt, B, C

def ssm(self, x, forward=True):
if forward:
A = -torch.exp(self.A_log_fwd.float())
Expand All @@ -324,6 +367,7 @@ def ssm(self, x, forward=True):
delta, B, C = torch.split(
deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
)
delta, B, C = self._apply_layernorms(delta, B, C)
delta = F.softplus(self.dt_proj_fwd(delta))
else:
A = -torch.exp(self.A_log_bwd.float())
Expand All @@ -332,6 +376,7 @@ def ssm(self, x, forward=True):
delta, B, C = torch.split(
deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
)
delta, B, C = self._apply_layernorms(delta, B, C)
delta = F.softplus(self.dt_proj_bwd(delta))

y = self.selective_scan_seq(x, delta, A, B, C, D)
Expand Down
Loading
Loading