adding hotfix for mamba arch
AnFreTh committed Aug 2, 2024
1 parent cc92798 commit 6a6a46a
Showing 6 changed files with 158 additions and 84 deletions.
51 changes: 48 additions & 3 deletions mambular/arch_utils/mamba_arch.py
@@ -43,6 +43,9 @@ def __init__(
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
layer_norm_eps=1e-05,
AB_weight_decay=False,
AB_layer_norm=True,
):
super().__init__()

@@ -66,6 +69,9 @@ def __init__(
activation,
bidirectional,
use_learnable_interaction,
layer_norm_eps,
AB_weight_decay,
AB_layer_norm,
)
for _ in range(n_layers)
]
@@ -105,6 +111,9 @@ def __init__(
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
layer_norm_eps=1e-05,
AB_weight_decay=False,
AB_layer_norm=False,
):
super().__init__()

@@ -149,8 +158,11 @@ def __init__(
activation=activation,
bidirectional=bidirectional,
use_learnable_interaction=use_learnable_interaction,
layer_norm_eps=layer_norm_eps,
AB_weight_decay=AB_weight_decay,
AB_layer_norm=AB_layer_norm,
)
self.norm = norm(d_model)
self.norm = norm(d_model, eps=layer_norm_eps)

def forward(self, x):
output = self.layers(self.norm(x)) + x
@@ -189,6 +201,9 @@ def __init__(
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
layer_norm_eps=1e-05,
AB_weight_decay=False,
AB_layer_norm=False,
):
super().__init__()
self.d_inner = d_model * expand_factor
@@ -239,6 +254,7 @@ def __init__(
elif dt_init == "random":
nn.init.uniform_(self.dt_proj_fwd.weight, -dt_init_std, dt_init_std)
if self.bidirectional:

nn.init.uniform_(self.dt_proj_bwd.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
@@ -262,17 +278,35 @@ def __init__(

A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
self.A_log_fwd = nn.Parameter(torch.log(A))
self.D_fwd = nn.Parameter(torch.ones(self.d_inner))

if self.bidirectional:
self.A_log_bwd = nn.Parameter(torch.log(A))
self.D_bwd = nn.Parameter(torch.ones(self.d_inner))

if not AB_weight_decay:
self.A_log_fwd._no_weight_decay = True
self.D_fwd._no_weight_decay = True

self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
if self.bidirectional:
self.D_bwd = nn.Parameter(torch.ones(self.d_inner))

if not AB_weight_decay:
self.A_log_bwd._no_weight_decay = True
self.D_bwd._no_weight_decay = True

self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
self.dt_rank = dt_rank
self.d_state = d_state

if AB_layer_norm:
self.dt_layernorm = RMSNorm(self.dt_rank, eps=layer_norm_eps)
self.B_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
self.C_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
else:
self.dt_layernorm = None
self.B_layernorm = None
self.C_layernorm = None

def forward(self, x):
_, L, _ = x.shape

@@ -316,6 +350,15 @@ def forward(self, x):

return output

def _apply_layernorms(self, dt, B, C):
if self.dt_layernorm is not None:
dt = self.dt_layernorm(dt)
if self.B_layernorm is not None:
B = self.B_layernorm(B)
if self.C_layernorm is not None:
C = self.C_layernorm(C)
return dt, B, C

def ssm(self, x, forward=True):
if forward:
A = -torch.exp(self.A_log_fwd.float())
@@ -324,6 +367,7 @@ def ssm(self, x, forward=True):
delta, B, C = torch.split(
deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
)
delta, B, C = self._apply_layernorms(delta, B, C)
delta = F.softplus(self.dt_proj_fwd(delta))
else:
A = -torch.exp(self.A_log_bwd.float())
@@ -332,6 +376,7 @@ def ssm(self, x, forward=True):
delta, B, C = torch.split(
deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
)
delta, B, C = self._apply_layernorms(delta, B, C)
delta = F.softplus(self.dt_proj_bwd(delta))

y = self.selective_scan_seq(x, delta, A, B, C, D)
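Note on the `_no_weight_decay` flags added above: they only take effect if the optimizer construction inspects them, and that consumer is not part of this diff. A minimal sketch, assuming a plain AdamW setup and a hypothetical `build_optimizer` helper, of how parameters carrying the flag could be routed into a zero-weight-decay group:

```python
import torch.nn as nn
from torch.optim import AdamW


def build_optimizer(model: nn.Module, lr: float = 1e-4, weight_decay: float = 1e-2) -> AdamW:
    """Split parameters on the `_no_weight_decay` attribute set in mamba_arch.py."""
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        # A_log_* and D_* carry `_no_weight_decay = True` when AB_weight_decay=False.
        if getattr(param, "_no_weight_decay", False):
            no_decay.append(param)
        else:
            decay.append(param)
    return AdamW(
        [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=lr,
    )
```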
25 changes: 20 additions & 5 deletions mambular/base_models/mambular.py
@@ -109,19 +109,34 @@ def __init__(
use_learnable_interaction=self.hparams.get(
"use_learnable_interactions", config.use_learnable_interaction
),
AB_weight_decay=self.hparams.get("AB_weight_decay", config.AB_weight_decay),
AB_layer_norm=self.hparams.get("AB_layer_norm", config.AB_layer_norm),
layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps),
)

norm_layer = self.hparams.get("norm", config.norm)
if norm_layer == "RMSNorm":
self.norm_f = RMSNorm(self.hparams.get("d_model", config.d_model))
self.norm_f = RMSNorm(
self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
)
elif norm_layer == "LayerNorm":
self.norm_f = LayerNorm(self.hparams.get("d_model", config.d_model))
self.norm_f = LayerNorm(
self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
)
elif norm_layer == "BatchNorm":
self.norm_f = BatchNorm(self.hparams.get("d_model", config.d_model))
self.norm_f = BatchNorm(
self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
)
elif norm_layer == "InstanceNorm":
self.norm_f = InstanceNorm(self.hparams.get("d_model", config.d_model))
self.norm_f = InstanceNorm(
self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
)
elif norm_layer == "GroupNorm":
self.norm_f = GroupNorm(1, self.hparams.get("d_model", config.d_model))
self.norm_f = GroupNorm(
1,
self.hparams.get("d_model", config.d_model),
eps=config.layer_norm_eps,
)
elif norm_layer == "LearnableLayerScaling":
self.norm_f = LearnableLayerScaling(
self.hparams.get("d_model", config.d_model)
18 changes: 14 additions & 4 deletions mambular/configs/mambular_config.py
@@ -49,8 +49,8 @@ class DefaultMambularConfig:
Normalization method to be used.
activation : callable, default=nn.SELU()
Activation function for the model.
num_embedding_activation : callable, default=nn.Identity()
Activation function for numerical embeddings.
embedding_activation : callable, default=nn.Identity()
Activation function for embeddings.
head_layer_sizes : list, default=(128, 64, 32)
Sizes of the layers in the head of the model.
head_dropout : float, default=0.5
@@ -70,7 +70,13 @@ class DefaultMambularConfig:
use_learnable_interaction : bool, default=False
Whether to use learnable feature interactions before passing through mamba blocks.
use_cls : bool, default=True
Whether to append a cls to the beginning of each 'sequence'.
Whether to append a cls to the end of each 'sequence'.
shuffle_embeddings : bool, default=False
Whether to shuffle the embeddings before being passed to the Mamba layers.
layer_norm_eps : float, default=1e-05
Epsilon value for layer normalization.
AB_weight_decay : bool, default=False
Whether weight decay is also applied to the A and B matrices.
"""

lr: float = 1e-04
@@ -93,7 +99,7 @@ class DefaultMambularConfig:
dt_init_floor: float = 1e-04
norm: str = "LayerNorm"
activation: callable = nn.SiLU()
num_embedding_activation: callable = nn.Identity()
embedding_activation: callable = nn.Identity()
head_layer_sizes: list = ()
head_dropout: float = 0.5
head_skip_layers: bool = False
@@ -104,3 +110,7 @@ class DefaultMambularConfig:
bidirectional: bool = False
use_learnable_interaction: bool = False
use_cls: bool = False
shuffle_embeddings: bool = False
layer_norm_eps: float = 1e-05
AB_weight_decay: bool = False
AB_layer_norm: bool = True
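For reference, the new options sit next to the existing fields; a short usage sketch, assuming `DefaultMambularConfig` is a standard dataclass and inferring the import path from the file location above:

```python
from mambular.configs.mambular_config import DefaultMambularConfig

# Override only the hotfix-related knobs; every other field keeps its default.
config = DefaultMambularConfig(
    layer_norm_eps=1e-6,    # shared epsilon for the normalization layers
    AB_weight_decay=False,  # keep A_log/D parameters out of weight decay
    AB_layer_norm=True,     # RMSNorm on the dt/B/C projections inside each block
)
print(config.layer_norm_eps, config.AB_weight_decay, config.AB_layer_norm)
```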
2 changes: 1 addition & 1 deletion mambular/models/sklearn_base_classifier.py
@@ -315,7 +315,7 @@ def fit(
self : object
The fitted classifier.
"""
if not self.built and not rebuild:
if (not self.built) or (self.built and rebuild):
if not isinstance(X, pd.DataFrame):
X = pd.DataFrame(X)
if isinstance(y, pd.Series):
76 changes: 39 additions & 37 deletions mambular/models/sklearn_base_lss.py
@@ -282,6 +282,7 @@ def fit(
checkpoint_path="model_checkpoints",
distributional_kwargs=None,
dataloader_kwargs={},
rebuild=True,
**trainer_kwargs
):
"""
@@ -357,45 +358,46 @@
else:
raise ValueError("Unsupported family: {}".format(family))

if not isinstance(X, pd.DataFrame):
X = pd.DataFrame(X)
if isinstance(y, pd.Series):
y = y.values
if X_val:
if not isinstance(X_val, pd.DataFrame):
X_val = pd.DataFrame(X_val)
if isinstance(y_val, pd.Series):
y_val = y_val.values

self.data_module = MambularDataModule(
preprocessor=self.preprocessor,
batch_size=batch_size,
shuffle=shuffle,
X_val=X_val,
y_val=y_val,
val_size=val_size,
random_state=random_state,
regression=True,
**dataloader_kwargs
)
if (not self.built) or (self.built and rebuild):
if not isinstance(X, pd.DataFrame):
X = pd.DataFrame(X)
if isinstance(y, pd.Series):
y = y.values
if X_val:
if not isinstance(X_val, pd.DataFrame):
X_val = pd.DataFrame(X_val)
if isinstance(y_val, pd.Series):
y_val = y_val.values

self.data_module = MambularDataModule(
preprocessor=self.preprocessor,
batch_size=batch_size,
shuffle=shuffle,
X_val=X_val,
y_val=y_val,
val_size=val_size,
random_state=random_state,
regression=True,
**dataloader_kwargs
)

self.data_module.preprocess_data(
X, y, X_val, y_val, val_size=val_size, random_state=random_state
)
self.data_module.preprocess_data(
X, y, X_val, y_val, val_size=val_size, random_state=random_state
)

self.model = TaskModel(
model_class=self.base_model,
num_classes=self.family.param_count,
family=self.family,
config=self.config,
cat_feature_info=self.data_module.cat_feature_info,
num_feature_info=self.data_module.num_feature_info,
lr=lr,
lr_patience=lr_patience,
lr_factor=factor,
weight_decay=weight_decay,
lss=True,
)
self.model = TaskModel(
model_class=self.base_model,
num_classes=self.family.param_count,
family=self.family,
config=self.config,
cat_feature_info=self.data_module.cat_feature_info,
num_feature_info=self.data_module.num_feature_info,
lr=lr,
lr_patience=lr_patience,
lr_factor=factor,
weight_decay=weight_decay,
lss=True,
)

early_stop_callback = EarlyStopping(
monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode
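With the same gating applied in this `fit`, a later call can reuse the already-built data module and `TaskModel` by passing `rebuild=False`. A usage sketch; the estimator class name, the `family` value, and `max_epochs` are illustrative assumptions rather than details taken from this diff:

```python
import numpy as np
import pandas as pd

from mambular.models import MambularLSS  # assumed public wrapper around sklearn_base_lss

X = pd.DataFrame(np.random.randn(256, 8))
y = np.random.randn(256)

model = MambularLSS()
model.fit(X, y, family="normal", max_epochs=5)                 # first call: builds data module + TaskModel
model.fit(X, y, family="normal", max_epochs=5, rebuild=False)  # later call: reuses the built model
```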