Merge pull request #98 from basf/AB_layer
adjust names of matrices
AnFreTh committed Aug 2, 2024
2 parents 0d4442a + 7c2d343 commit 8933cf7
Showing 3 changed files with 21 additions and 19 deletions.
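In short, this commit renames two options so they match the matrices they actually touch: AB_weight_decay becomes AD_weight_decay (weight decay on the A and D parameters) and AB_layer_norm becomes BC_layer_norm (RMSNorm over the dt, B, and C projections). A minimal sketch of setting the renamed fields after this commit; the import path is inferred from the file list below and should be treated as an assumption:

# Hedged sketch: field names follow this commit; the import path is assumed
# from mambular/configs/mambular_config.py in this repository.
from mambular.configs.mambular_config import DefaultMambularConfig

config = DefaultMambularConfig(
    AD_weight_decay=False,  # keep weight decay off the A and D matrices (default)
    BC_layer_norm=True,     # apply RMSNorm to the dt/B/C projections (default)
)
print(config.AD_weight_decay, config.BC_layer_norm)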
26 changes: 13 additions & 13 deletions mambular/arch_utils/mamba_arch.py
@@ -44,8 +44,8 @@ def __init__(
         bidirectional=False,
         use_learnable_interaction=False,
         layer_norm_eps=1e-05,
-        AB_weight_decay=False,
-        AB_layer_norm=True,
+        AD_weight_decay=False,
+        BC_layer_norm=True,
     ):
         super().__init__()

@@ -70,8 +70,8 @@ def __init__(
                     bidirectional,
                     use_learnable_interaction,
                     layer_norm_eps,
-                    AB_weight_decay,
-                    AB_layer_norm,
+                    AD_weight_decay,
+                    BC_layer_norm,
                 )
                 for _ in range(n_layers)
             ]
@@ -112,8 +112,8 @@ def __init__(
         bidirectional=False,
         use_learnable_interaction=False,
         layer_norm_eps=1e-05,
-        AB_weight_decay=False,
-        AB_layer_norm=False,
+        AD_weight_decay=False,
+        BC_layer_norm=False,
     ):
         super().__init__()

@@ -159,8 +159,8 @@ def __init__(
             bidirectional=bidirectional,
             use_learnable_interaction=use_learnable_interaction,
             layer_norm_eps=layer_norm_eps,
-            AB_weight_decay=AB_weight_decay,
-            AB_layer_norm=AB_layer_norm,
+            AD_weight_decay=AD_weight_decay,
+            BC_layer_norm=BC_layer_norm,
         )
         self.norm = norm(d_model, eps=layer_norm_eps)

@@ -202,8 +202,8 @@ def __init__(
         bidirectional=False,
         use_learnable_interaction=False,
         layer_norm_eps=1e-05,
-        AB_weight_decay=False,
-        AB_layer_norm=False,
+        AD_weight_decay=False,
+        BC_layer_norm=False,
     ):
         super().__init__()
         self.d_inner = d_model * expand_factor
@@ -284,21 +284,21 @@ def __init__(
             self.A_log_bwd = nn.Parameter(torch.log(A))
             self.D_bwd = nn.Parameter(torch.ones(self.d_inner))

-        if not AB_weight_decay:
+        if not AD_weight_decay:
             self.A_log_fwd._no_weight_decay = True
             self.D_fwd._no_weight_decay = True

         if self.bidirectional:

-            if not AB_weight_decay:
+            if not AD_weight_decay:
                 self.A_log_bwd._no_weight_decay = True
                 self.D_bwd._no_weight_decay = True

         self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
         self.dt_rank = dt_rank
         self.d_state = d_state

-        if AB_layer_norm:
+        if BC_layer_norm:
             self.dt_layernorm = RMSNorm(self.dt_rank, eps=layer_norm_eps)
             self.B_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
             self.C_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
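For context, the _no_weight_decay attribute set above is a conventional marker that optimizer setup code can read to exempt the A and D parameters from decay. A minimal sketch of how such a flag is typically consumed; build_param_groups is a hypothetical helper, not part of mambular:

# Hedged sketch: one common way to honor the _no_weight_decay marker when
# building optimizer parameter groups. build_param_groups is hypothetical.
import torch

def build_param_groups(model, weight_decay=1e-2):
    decay, no_decay = [], []
    for param in model.parameters():
        if getattr(param, "_no_weight_decay", False):
            no_decay.append(param)  # e.g. A_log_fwd / D_fwd when AD_weight_decay=False
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# optimizer = torch.optim.AdamW(build_param_groups(model))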
4 changes: 2 additions & 2 deletions mambular/base_models/mambular.py
@@ -109,8 +109,8 @@ def __init__(
             use_learnable_interaction=self.hparams.get(
                 "use_learnable_interactions", config.use_learnable_interaction
             ),
-            AB_weight_decay=self.hparams.get("AB_weight_decay", config.AB_weight_decay),
-            AB_layer_norm=self.hparams.get("AB_layer_norm", config.AB_layer_norm),
+            AD_weight_decay=self.hparams.get("AB_weight_decay", config.AD_weight_decay),
+            BC_layer_norm=self.hparams.get("AB_layer_norm", config.BC_layer_norm),
             layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps),
         )
         norm_layer = self.hparams.get("norm", config.norm)
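The self.hparams.get(key, default) calls above return a stored override when one exists and otherwise fall back to the config value; note that the lookup keys on the two renamed lines still read "AB_weight_decay" and "AB_layer_norm". A tiny illustration of the fallback, with a plain dict standing in for self.hparams:

# Hedged illustration: a plain dict stands in for the self.hparams mapping
# used in the file above; defaults mirror the config values from this commit.
hparams = {}  # no user overrides stored
AD_weight_decay = hparams.get("AB_weight_decay", False)  # falls back to False
BC_layer_norm = hparams.get("AB_layer_norm", True)       # falls back to True
print(AD_weight_decay, BC_layer_norm)  # False True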
10 changes: 6 additions & 4 deletions mambular/configs/mambular_config.py
@@ -75,8 +75,10 @@ class DefaultMambularConfig:
         Whether to shuffle the embeddings before being passed to the Mamba layers.
     layer_norm_eps : float, default=1e-05
         Epsilon value for layer normalization.
-    AB_weight_decay : bool, default=False
-        wether weight decay is also applied to A-B matrices
+    AD_weight_decay : bool, default=False
+        whether weight decay is also applied to A-D matrices
+    BC_layer_norm: bool, default=True
+        whether to apply layer normalization to B-C matrices
     """

     lr: float = 1e-04
@@ -112,5 +114,5 @@ class DefaultMambularConfig:
     use_cls: bool = False
     shuffle_embeddings: bool = False
     layer_norm_eps: float = 1e-05
-    AB_weight_decay: bool = False
-    AB_layer_norm: bool = True
+    AD_weight_decay: bool = False
+    BC_layer_norm: bool = True
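The three RMSNorm modules created under BC_layer_norm are typically applied to the dt, B, and C tensors that the input projection is split into during the selective-scan step. A minimal sketch of that application; the method name and None handling are assumed, not copied from mambular:

# Hedged sketch: how the dt/B/C RMSNorms created under BC_layer_norm are
# commonly applied in Mamba-style forward passes. Assumes the norm attributes
# are set to None when BC_layer_norm is False.
def _apply_layernorms(self, dt, B, C):
    if getattr(self, "dt_layernorm", None) is not None:
        dt = self.dt_layernorm(dt)
    if getattr(self, "B_layernorm", None) is not None:
        B = self.B_layernorm(B)
    if getattr(self, "C_layernorm", None) is not None:
        C = self.C_layernorm(C)
    return dt, B, C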
