Merge pull request #47 from basf/restructure
Restructure
AnFreTh authored Jun 28, 2024
2 parents aa6b3be + 893b69c commit 7813141
Showing 9 changed files with 186 additions and 52 deletions.
158 changes: 122 additions & 36 deletions mambular/arch_utils/mamba_arch.py
@@ -41,6 +41,8 @@ def __init__(
dt_init_floor=1e-04,
norm=RMSNorm,
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
):
super().__init__()

@@ -62,6 +64,8 @@ def __init__(
dt_init_floor,
norm,
activation,
bidirectional,
use_learnable_interaction,
)
for _ in range(n_layers)
]
@@ -99,6 +103,8 @@ def __init__(
dt_init_floor=1e-04,
norm=RMSNorm,
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
):
super().__init__()

@@ -141,6 +147,8 @@ def __init__(
dt_min=dt_min,
dt_init_floor=dt_init_floor,
activation=activation,
bidirectional=bidirectional,
use_learnable_interaction=use_learnable_interaction,
)
self.norm = norm(d_model)

@@ -153,14 +161,14 @@ class MambaBlock(nn.Module):
"""MambaBlock module containing the main computational components.
Attributes:
config (MambularConfig): Configuration object for the MambaBlock.
in_proj (nn.Linear): Linear projection for input.
conv1d (nn.Conv1d): 1D convolutional layer.
x_proj (nn.Linear): Linear projection for input-dependent tensors.
dt_proj (nn.Linear): Linear projection for the time-step delta.
A_log (nn.Parameter): Logarithmically stored A tensor.
D (nn.Parameter): Tensor for D component.
out_proj (nn.Linear): Linear projection for output.
learnable_interaction (LearnableFeatureInteraction): Learnable feature interaction layer.
"""

def __init__(
@@ -179,88 +187,154 @@ def __init__(
dt_min=1e-03,
dt_init_floor=1e-04,
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
):
super().__init__()
self.d_inner = d_model * expand_factor
self.bidirectional = bidirectional
self.use_learnable_interaction = use_learnable_interaction

self.in_proj = nn.Linear(d_model, 2 * self.d_inner, bias=bias)
self.in_proj_fwd = nn.Linear(d_model, 2 * self.d_inner, bias=bias)
if self.bidirectional:
self.in_proj_bwd = nn.Linear(d_model, 2 * self.d_inner, bias=bias)

self.conv1d = nn.Conv1d(
self.conv1d_fwd = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
kernel_size=d_conv,
bias=conv_bias,
groups=self.d_inner,
padding=d_conv - 1,
)
if self.bidirectional:
self.conv1d_bwd = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
kernel_size=d_conv,
bias=conv_bias,
groups=self.d_inner,
padding=d_conv - 1,
)

self.dropout = nn.Dropout(dropout)
self.activation = activation

self.x_proj = nn.Linear(self.d_inner, dt_rank + 2 * d_state, bias=False)
if self.use_learnable_interaction:
self.learnable_interaction = LearnableFeatureInteraction(self.d_inner)

self.x_proj_fwd = nn.Linear(self.d_inner, dt_rank + 2 * d_state, bias=False)
if self.bidirectional:
self.x_proj_bwd = nn.Linear(self.d_inner, dt_rank + 2 * d_state, bias=False)

self.dt_proj = nn.Linear(dt_rank, self.d_inner, bias=True)
self.dt_proj_fwd = nn.Linear(dt_rank, self.d_inner, bias=True)
if self.bidirectional:
self.dt_proj_bwd = nn.Linear(dt_rank, self.d_inner, bias=True)

dt_init_std = dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(self.dt_proj.weight, dt_init_std)
nn.init.constant_(self.dt_proj_fwd.weight, dt_init_std)
if self.bidirectional:
nn.init.constant_(self.dt_proj_bwd.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
nn.init.uniform_(self.dt_proj_fwd.weight, -dt_init_std, dt_init_std)
if self.bidirectional:
nn.init.uniform_(self.dt_proj_bwd.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError

dt = torch.exp(
dt_fwd = torch.exp(
torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
inv_dt = dt + torch.log(-torch.expm1(-dt))
inv_dt_fwd = dt_fwd + torch.log(-torch.expm1(-dt_fwd))
with torch.no_grad():
self.dt_proj.bias.copy_(inv_dt)
self.dt_proj_fwd.bias.copy_(inv_dt_fwd)

if self.bidirectional:
dt_bwd = torch.exp(
torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
inv_dt_bwd = dt_bwd + torch.log(-torch.expm1(-dt_bwd))
with torch.no_grad():
self.dt_proj_bwd.bias.copy_(inv_dt_bwd)

A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.d_inner))
self.A_log_fwd = nn.Parameter(torch.log(A))
if self.bidirectional:
self.A_log_bwd = nn.Parameter(torch.log(A))

self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
if self.bidirectional:
self.D_bwd = nn.Parameter(torch.ones(self.d_inner))

self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
self.dt_rank = dt_rank
self.d_state = d_state

def forward(self, x):
_, L, _ = x.shape

xz = self.in_proj(x)
x, z = xz.chunk(2, dim=-1)
xz_fwd = self.in_proj_fwd(x)
x_fwd, z_fwd = xz_fwd.chunk(2, dim=-1)

x = x.transpose(1, 2)
x = self.conv1d(x)[:, :, :L]
x = x.transpose(1, 2)
x_fwd = x_fwd.transpose(1, 2)
x_fwd = self.conv1d_fwd(x_fwd)[:, :, :L]
x_fwd = x_fwd.transpose(1, 2)

x = self.activation(x)
x = self.dropout(x)
y = self.ssm(x)
if self.bidirectional:
xz_bwd = self.in_proj_bwd(x)
x_bwd, z_bwd = xz_bwd.chunk(2, dim=-1)

z = self.activation(z)
z = self.dropout(z)
x_bwd = x_bwd.transpose(1, 2)
x_bwd = self.conv1d_bwd(x_bwd)[:, :, :L]
x_bwd = x_bwd.transpose(1, 2)

output = y * z
output = self.out_proj(output)
if self.use_learnable_interaction:
x_fwd = self.learnable_interaction(x_fwd)
if self.bidirectional:
x_bwd = self.learnable_interaction(x_bwd)

return output
x_fwd = self.activation(x_fwd)
x_fwd = self.dropout(x_fwd)
y_fwd = self.ssm(x_fwd, forward=True)

def ssm(self, x):
A = -torch.exp(self.A_log.float())
D = self.D.float()
if self.bidirectional:
x_bwd = self.activation(x_bwd)
x_bwd = self.dropout(x_bwd)
y_bwd = self.ssm(torch.flip(x_bwd, [1]), forward=False)
y = y_fwd + torch.flip(y_bwd, [1])
else:
y = y_fwd

deltaBC = self.x_proj(x)
z_fwd = self.activation(z_fwd)
z_fwd = self.dropout(z_fwd)

delta, B, C = torch.split(
deltaBC,
[self.dt_rank, self.d_state, self.d_state],
dim=-1,
)
delta = F.softplus(self.dt_proj(delta))
output = y * z_fwd
output = self.out_proj(output)

y = self.selective_scan_seq(x, delta, A, B, C, D)
return output

def ssm(self, x, forward=True):
if forward:
A = -torch.exp(self.A_log_fwd.float())
D = self.D_fwd.float()
deltaBC = self.x_proj_fwd(x)
delta, B, C = torch.split(
deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
)
delta = F.softplus(self.dt_proj_fwd(delta))
else:
A = -torch.exp(self.A_log_bwd.float())
D = self.D_bwd.float()
deltaBC = self.x_proj_bwd(x)
delta, B, C = torch.split(
deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
)
delta = F.softplus(self.dt_proj_bwd(delta))

y = self.selective_scan_seq(x, delta, A, B, C, D)
return y

def selective_scan_seq(self, x, delta, A, B, C, D):
@@ -285,3 +359,15 @@ def selective_scan_seq(self, x, delta, A, B, C, D):
y = y + D * x

return y
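
The body of selective_scan_seq is collapsed in this diff. For orientation, the snippet below is a minimal sketch of a sequential selective scan with the standard zero-order-hold discretization; the helper name, shapes, and loop structure are illustrative assumptions rather than the repository's exact implementation, and only the final skip connection y = y + D * x is taken from the diff above.

import torch

def selective_scan_seq_sketch(x, delta, A, B, C, D):
    # Assumed shapes:
    #   x:     (batch, L, d_inner)   post-convolution activations
    #   delta: (batch, L, d_inner)   softplus time steps from dt_proj
    #   A:     (d_inner, d_state)    negative state matrix, i.e. -exp(A_log)
    #   B, C:  (batch, L, d_state)   input-dependent projections from x_proj
    #   D:     (d_inner,)            skip-connection weights
    batch, L, _ = x.shape
    d_inner, d_state = A.shape

    # Zero-order-hold discretization of the continuous SSM.
    deltaA = torch.exp(delta.unsqueeze(-1) * A)                       # (batch, L, d_inner, d_state)
    deltaBx = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)  # (batch, L, d_inner, d_state)

    h = torch.zeros(batch, d_inner, d_state, device=x.device)
    outputs = []
    for t in range(L):
        h = deltaA[:, t] * h + deltaBx[:, t]                      # recurrent state update
        outputs.append((h @ C[:, t].unsqueeze(-1)).squeeze(-1))   # read out: (batch, d_inner)
    y = torch.stack(outputs, dim=1)                               # (batch, L, d_inner)
    return y + D * x                                              # skip connection, as in the diff
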


class LearnableFeatureInteraction(nn.Module):
def __init__(self, n_vars):
super().__init__()
self.interaction_weights = nn.Parameter(torch.Tensor(n_vars, n_vars))
nn.init.xavier_uniform_(self.interaction_weights)

def forward(self, x):
batch_size, n_vars, d_model = x.size()
interactions = torch.matmul(x, self.interaction_weights)
return interactions.view(batch_size, n_vars, d_model)
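
The new LearnableFeatureInteraction layer right-multiplies the last dimension of its input by a learnable n_vars x n_vars matrix, so inside MambaBlock (where it is built with n_vars = d_inner) it mixes the d_inner channels at every sequence position. A quick standalone shape check, with toy sizes chosen purely for illustration:

import torch

layer = LearnableFeatureInteraction(n_vars=16)  # must match the last (feature) dimension of x
x = torch.randn(8, 32, 16)                      # (batch, seq_len, d_inner)
out = layer(x)                                  # matmul against the learnable (16, 16) weight
assert out.shape == x.shape                     # shape is preserved: (8, 32, 16)
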
4 changes: 4 additions & 0 deletions mambular/base_models/mambular.py
@@ -105,6 +105,10 @@ def __init__(
dt_init_floor=self.hparams.get("dt_init_floor", config.dt_init_floor),
norm=globals()[self.hparams.get("norm", config.norm)],
activation=self.hparams.get("activation", config.activation),
bidirectional=self.hparams.get("bidiretional", config.bidirectional),
use_learnable_interaction=self.hparams.get(
"use_learnable_interactions", config.use_learnable_interaction
),
)

norm_layer = self.hparams.get("norm", config.norm)
2 changes: 1 addition & 1 deletion mambular/configs/fttransformer_config.py
@@ -27,4 +27,4 @@ class DefaultFTTransformerConfig:
bias: bool = True
transformer_activation: callable = nn.SELU()
layer_norm_eps: float = 1e-05
transformer_dim_feedforward: int = 2048
transformer_dim_feedforward: int = 512
2 changes: 2 additions & 0 deletions mambular/configs/mambular_config.py
@@ -32,3 +32,5 @@ class DefaultMambularConfig:
head_use_batch_norm: bool = False
layer_norm_after_embedding: bool = False
pooling_method: str = "avg"
bidirectional: bool = False
use_learnable_interaction: bool = False
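
Both new fields default to False. Assuming DefaultMambularConfig is a plain dataclass that accepts keyword overrides (as its annotated defaults suggest), the new behavior can be switched on like this; the import path is taken from the file header above:

from mambular.configs.mambular_config import DefaultMambularConfig

# Enable the two features introduced in this PR; every other field keeps its default.
config = DefaultMambularConfig(bidirectional=True, use_learnable_interaction=True)
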
2 changes: 1 addition & 1 deletion mambular/configs/tabtransformer_config.py
@@ -27,4 +27,4 @@ class DefaultTabTransformerConfig:
bias: bool = True
transformer_activation: callable = nn.SELU()
layer_norm_eps: float = 1e-05
transformer_dim_feedforward: int = 2048
transformer_dim_feedforward: int = 512
2 changes: 2 additions & 0 deletions mambular/models/sklearn_base_classifier.py
@@ -21,6 +21,8 @@ def __init__(self, model, config, **kwargs):
"task",
"cat_cutoff",
"treat_all_integers_as_numerical",
"knots",
"degree",
]

self.config_kwargs = {
2 changes: 2 additions & 0 deletions mambular/models/sklearn_base_lss.py
@@ -43,6 +43,8 @@ def __init__(self, model, config, **kwargs):
"task",
"cat_cutoff",
"treat_all_integers_as_numerical",
"knots",
"degree",
]

self.config_kwargs = {
2 changes: 2 additions & 0 deletions mambular/models/sklearn_base_regressor.py
@@ -20,6 +20,8 @@ def __init__(self, model, config, **kwargs):
"task",
"cat_cutoff",
"treat_all_integers_as_numerical",
"knots",
"degree",
]

self.config_kwargs = {
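
Each of the three sklearn-style base wrappers now lists knots and degree among the keyword arguments that are apparently routed to the preprocessor rather than to the model config (judging by the names alongside them). Assuming the public regressor is exposed as MambularRegressor (the class name comes from the package, not from this diff) and that extra keywords are filtered by the list shown above, usage might look like:

from mambular.models import MambularRegressor

# Hypothetical example: "knots" and "degree" are forwarded to the numerical
# preprocessor instead of being rejected as unknown keyword arguments.
model = MambularRegressor(knots=20, degree=3)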