diff --git a/mambular/arch_utils/mamba_arch.py b/mambular/arch_utils/mamba_arch.py
index 6417c7f..537b8e5 100644
--- a/mambular/arch_utils/mamba_arch.py
+++ b/mambular/arch_utils/mamba_arch.py
@@ -44,8 +44,8 @@ def __init__(
         bidirectional=False,
         use_learnable_interaction=False,
         layer_norm_eps=1e-05,
-        AB_weight_decay=False,
-        AB_layer_norm=True,
+        AD_weight_decay=False,
+        BC_layer_norm=True,
     ):
         super().__init__()

@@ -70,8 +70,8 @@ def __init__(
                     bidirectional,
                     use_learnable_interaction,
                     layer_norm_eps,
-                    AB_weight_decay,
-                    AB_layer_norm,
+                    AD_weight_decay,
+                    BC_layer_norm,
                 )
                 for _ in range(n_layers)
             ]
@@ -112,8 +112,8 @@ def __init__(
         bidirectional=False,
         use_learnable_interaction=False,
         layer_norm_eps=1e-05,
-        AB_weight_decay=False,
-        AB_layer_norm=False,
+        AD_weight_decay=False,
+        BC_layer_norm=False,
     ):
         super().__init__()

@@ -159,8 +159,8 @@ def __init__(
             bidirectional=bidirectional,
             use_learnable_interaction=use_learnable_interaction,
             layer_norm_eps=layer_norm_eps,
-            AB_weight_decay=AB_weight_decay,
-            AB_layer_norm=AB_layer_norm,
+            AD_weight_decay=AD_weight_decay,
+            BC_layer_norm=BC_layer_norm,
         )

         self.norm = norm(d_model, eps=layer_norm_eps)
@@ -202,8 +202,8 @@ def __init__(
         bidirectional=False,
         use_learnable_interaction=False,
         layer_norm_eps=1e-05,
-        AB_weight_decay=False,
-        AB_layer_norm=False,
+        AD_weight_decay=False,
+        BC_layer_norm=False,
     ):
         super().__init__()
         self.d_inner = d_model * expand_factor
@@ -284,13 +284,13 @@ def __init__(
             self.A_log_bwd = nn.Parameter(torch.log(A))
             self.D_bwd = nn.Parameter(torch.ones(self.d_inner))

-        if not AB_weight_decay:
+        if not AD_weight_decay:
             self.A_log_fwd._no_weight_decay = True
             self.D_fwd._no_weight_decay = True

         if self.bidirectional:
-            if not AB_weight_decay:
+            if not AD_weight_decay:
                 self.A_log_bwd._no_weight_decay = True
                 self.D_bwd._no_weight_decay = True

@@ -298,7 +298,7 @@ def __init__(
         self.dt_rank = dt_rank
         self.d_state = d_state

-        if AB_layer_norm:
+        if BC_layer_norm:
             self.dt_layernorm = RMSNorm(self.dt_rank, eps=layer_norm_eps)
             self.B_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
             self.C_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
diff --git a/mambular/base_models/mambular.py b/mambular/base_models/mambular.py
index 53d5a3e..33b2b6f 100644
--- a/mambular/base_models/mambular.py
+++ b/mambular/base_models/mambular.py
@@ -109,8 +109,8 @@ def __init__(
             use_learnable_interaction=self.hparams.get(
                 "use_learnable_interactions", config.use_learnable_interaction
             ),
-            AB_weight_decay=self.hparams.get("AB_weight_decay", config.AB_weight_decay),
-            AB_layer_norm=self.hparams.get("AB_layer_norm", config.AB_layer_norm),
+            AD_weight_decay=self.hparams.get("AD_weight_decay", config.AD_weight_decay),
+            BC_layer_norm=self.hparams.get("BC_layer_norm", config.BC_layer_norm),
             layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps),
         )
         norm_layer = self.hparams.get("norm", config.norm)
diff --git a/mambular/configs/mambular_config.py b/mambular/configs/mambular_config.py
index c9b8afa..c6fcd89 100644
--- a/mambular/configs/mambular_config.py
+++ b/mambular/configs/mambular_config.py
@@ -75,8 +75,10 @@ class DefaultMambularConfig:
         Whether to shuffle the embeddings before being passed to the Mamba layers.
     layer_norm_eps : float, default=1e-05
         Epsilon value for layer normalization.
-    AB_weight_decay : bool, default=False
-        wether weight decay is also applied to A-B matrices
+    AD_weight_decay : bool, default=False
+        whether weight decay is also applied to A-D matrices
+    BC_layer_norm : bool, default=True
+        whether to apply layer normalization to B-C matrices
     """

     lr: float = 1e-04
@@ -112,5 +114,5 @@ class DefaultMambularConfig:
     use_cls: bool = False
     shuffle_embeddings: bool = False
     layer_norm_eps: float = 1e-05
-    AB_weight_decay: bool = False
-    AB_layer_norm: bool = True
+    AD_weight_decay: bool = False
+    BC_layer_norm: bool = True
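
For reference, a minimal usage sketch of the renamed options (hypothetical, not part of the diff; it assumes DefaultMambularConfig is a plain dataclass, as the field/default syntax above suggests, and the import path is taken from the diff):

    from mambular.configs.mambular_config import DefaultMambularConfig

    # AD_weight_decay=False keeps the A_log and D parameters flagged with
    # _no_weight_decay, i.e. excluded from weight decay (see mamba_arch.py above).
    # BC_layer_norm=True applies RMSNorm to the dt, B and C projections.
    config = DefaultMambularConfig(
        AD_weight_decay=False,
        BC_layer_norm=True,
    )
    print(config.AD_weight_decay, config.BC_layer_norm)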