
Commit

Merge pull request #90 from basf/layer_improvement
Layer improvement
AnFreTh authored Jul 26, 2024
2 parents cc92798 + 19b760c commit 56801dd
Showing 28 changed files with 1,093 additions and 261 deletions.
102 changes: 102 additions & 0 deletions mambular/arch_utils/attention_net_arch_utils.py
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn


class Reshape(nn.Module):
def __init__(self, j, dim, method="linear"):
super(Reshape, self).__init__()
self.j = j
self.dim = dim
self.method = method

if self.method == "linear":
# Use nn.Linear approach
self.layer = nn.Linear(dim, j * dim)
elif self.method == "embedding":
            # Use nn.Embedding approach (note: expects integer index inputs)
self.layer = nn.Embedding(dim, j * dim)
elif self.method == "conv1d":
# Use nn.Conv1d approach
self.layer = nn.Conv1d(in_channels=dim, out_channels=j * dim, kernel_size=1)
else:
raise ValueError(f"Unsupported method '{method}' for reshaping.")

def forward(self, x):
batch_size = x.shape[0]

if self.method == "linear" or self.method == "embedding":
x_reshaped = self.layer(x) # shape: (batch_size, j * dim)
x_reshaped = x_reshaped.view(
batch_size, self.j, self.dim
) # shape: (batch_size, j, dim)
elif self.method == "conv1d":
# For Conv1d, add dummy dimension and reshape
x = x.unsqueeze(-1) # Add dummy dimension for convolution
x_reshaped = self.layer(x) # shape: (batch_size, j * dim, 1)
x_reshaped = x_reshaped.squeeze(-1) # Remove dummy dimension
x_reshaped = x_reshaped.view(
batch_size, self.j, self.dim
) # shape: (batch_size, j, dim)

return x_reshaped


class AttentionNetBlock(nn.Module):
def __init__(
self,
channels,
in_channels,
d_model,
n_heads,
n_layers,
dim_feedforward,
transformer_activation,
output_dim,
attn_dropout,
layer_norm_eps,
norm_first,
bias,
activation,
embedding_activation,
norm_f,
method,
):
super(AttentionNetBlock, self).__init__()

self.reshape = Reshape(channels, in_channels, method)

encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=n_heads,
batch_first=True,
dim_feedforward=dim_feedforward,
dropout=attn_dropout,
activation=transformer_activation,
layer_norm_eps=layer_norm_eps,
norm_first=norm_first,
bias=bias,
)

self.encoder = nn.TransformerEncoder(
encoder_layer,
num_layers=n_layers,
norm=norm_f,
)

self.linear = nn.Linear(d_model, output_dim)
self.activation = activation
self.embedding_activation = embedding_activation

def forward(self, x):
z = self.reshape(x)
x = self.embedding_activation(z)
x = self.encoder(x)
x = z + x
x = torch.sum(x, dim=1)
x = self.linear(x)
x = self.activation(x)
return x
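For orientation, here is a minimal usage sketch of the new block, assuming the package layout shown in this diff; all configuration values below are illustrative choices, not defaults from this PR. Note that d_model must equal in_channels so the reshaped tokens match the encoder width and the residual z + x lines up.

import torch
import torch.nn as nn
from mambular.arch_utils.attention_net_arch_utils import AttentionNetBlock

# Hypothetical setup: a flat input of 16 features expanded to 8 tokens of width 16.
block = AttentionNetBlock(
    channels=8,                      # j: number of tokens produced by Reshape
    in_channels=16,                  # width of the flat input vector
    d_model=16,                      # must equal in_channels (see note above)
    n_heads=4,
    n_layers=2,
    dim_feedforward=64,
    transformer_activation="gelu",
    output_dim=1,
    attn_dropout=0.1,
    layer_norm_eps=1e-5,
    norm_first=True,
    bias=True,
    activation=nn.Identity(),
    embedding_activation=nn.ReLU(),
    norm_f=nn.LayerNorm(16),
    method="linear",
)

x = torch.randn(32, 16)              # (batch_size, in_channels)
out = block(x)                       # (batch_size, output_dim) == (32, 1)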
97 changes: 97 additions & 0 deletions mambular/arch_utils/attention_utils.py
@@ -0,0 +1,97 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding


class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim=-1)
return x * F.gelu(gates)


def FeedForward(dim, mult=4, dropout=0.0):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim),
)


class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head**-0.5
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
self.dropout = nn.Dropout(dropout)
self.rotary = rotary
        # Rotary embeddings are built on half the model dimension (applied only when rotary=True)
        self.rotary_embedding = RotaryEmbedding(dim=dim // 2)

def forward(self, x):
h = self.heads
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
if self.rotary:
q = self.rotary_embedding.rotate_queries_or_keys(q)
k = self.rotary_embedding.rotate_queries_or_keys(k)
q = q * self.scale

sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)

attn = sim.softmax(dim=-1)
dropped_attn = self.dropout(attn)

out = torch.einsum("b h i j, b h j d -> b h i d", dropped_attn, v)
out = rearrange(out, "b h n d -> b n (h d)", h=h)
out = self.to_out(out)

return out, attn


class Transformer(nn.Module):
def __init__(
self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary=False
):
super().__init__()
self.layers = nn.ModuleList([])

for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
Attention(
dim,
heads=heads,
dim_head=dim_head,
dropout=attn_dropout,
rotary=rotary,
),
FeedForward(dim, dropout=ff_dropout),
]
)
)

def forward(self, x, return_attn=False):
post_softmax_attns = []

for attn, ff in self.layers:
attn_out, post_softmax_attn = attn(x)
post_softmax_attns.append(post_softmax_attn)

x = attn_out + x
x = ff(x) + x

if not return_attn:
return x

return x, torch.stack(post_softmax_attns)
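A short usage sketch of the Transformer wrapper above (module path taken from this diff; shapes and hyperparameters are illustrative assumptions):

import torch
from mambular.arch_utils.attention_utils import Transformer

model = Transformer(
    dim=64,            # model width
    depth=2,           # number of attention + feed-forward blocks
    heads=8,
    dim_head=64,       # per-head width (matches the Attention default in the diff)
    attn_dropout=0.1,
    ff_dropout=0.1,
    rotary=True,       # rotate queries/keys with rotary position embeddings
)

x = torch.randn(32, 10, 64)              # (batch, tokens, dim)
out = model(x)                           # (32, 10, 64)
out, attn = model(x, return_attn=True)   # attn: (depth, batch, heads, tokens, tokens)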
145 changes: 145 additions & 0 deletions mambular/arch_utils/embedding_layer.py
@@ -0,0 +1,145 @@
import torch
import torch.nn as nn


class EmbeddingLayer(nn.Module):
def __init__(
self,
num_feature_info,
cat_feature_info,
d_model,
embedding_activation=nn.Identity(),
layer_norm_after_embedding=False,
use_cls=False,
cls_position=0,
):
"""
Embedding layer that handles numerical and categorical embeddings.
Parameters
----------
num_feature_info : dict
Dictionary where keys are numerical feature names and values are their respective input dimensions.
cat_feature_info : dict
Dictionary where keys are categorical feature names and values are the number of categories for each feature.
d_model : int
Dimensionality of the embeddings.
embedding_activation : nn.Module, optional
Activation function to apply after embedding. Default is `nn.Identity()`.
layer_norm_after_embedding : bool, optional
If True, applies layer normalization after embeddings. Default is `False`.
use_cls : bool, optional
If True, includes a class token in the embeddings. Default is `False`.
cls_position : int, optional
Position to place the class token, either at the start (0) or end (1) of the sequence. Default is `0`.
Methods
-------
forward(num_features=None, cat_features=None)
Defines the forward pass of the model.
"""
super(EmbeddingLayer, self).__init__()

self.d_model = d_model
self.embedding_activation = embedding_activation
self.layer_norm_after_embedding = layer_norm_after_embedding
self.use_cls = use_cls
self.cls_position = cls_position

self.num_embeddings = nn.ModuleList(
[
nn.Sequential(
nn.Linear(input_shape, d_model, bias=False),
self.embedding_activation,
)
for feature_name, input_shape in num_feature_info.items()
]
)

self.cat_embeddings = nn.ModuleList(
[
nn.Sequential(
nn.Embedding(num_categories + 1, d_model),
self.embedding_activation,
)
for feature_name, num_categories in cat_feature_info.items()
]
)

if self.use_cls:
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
if layer_norm_after_embedding:
self.embedding_norm = nn.LayerNorm(d_model)

self.seq_len = len(self.num_embeddings) + len(self.cat_embeddings)

def forward(self, num_features=None, cat_features=None):
"""
Defines the forward pass of the model.
Parameters
----------
num_features : Tensor, optional
Tensor containing the numerical features.
cat_features : Tensor, optional
Tensor containing the categorical features.
Returns
-------
Tensor
The output embeddings of the model.
Raises
------
ValueError
If no features are provided to the model.
"""
        if self.use_cls:
            batch_size = (
                cat_features[0].size(0)
                if cat_features is not None and len(cat_features) > 0
                else num_features[0].size(0)
            )
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)

if self.cat_embeddings and cat_features is not None:
cat_embeddings = [
emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
]
cat_embeddings = torch.stack(cat_embeddings, dim=1)
cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
if self.layer_norm_after_embedding:
cat_embeddings = self.embedding_norm(cat_embeddings)
else:
cat_embeddings = None

if self.num_embeddings and num_features is not None:
num_embeddings = [
emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)
]
num_embeddings = torch.stack(num_embeddings, dim=1)
if self.layer_norm_after_embedding:
num_embeddings = self.embedding_norm(num_embeddings)
else:
num_embeddings = None

if cat_embeddings is not None and num_embeddings is not None:
x = torch.cat([cat_embeddings, num_embeddings], dim=1)
elif cat_embeddings is not None:
x = cat_embeddings
elif num_embeddings is not None:
x = num_embeddings
else:
raise ValueError("No features provided to the model.")

if self.use_cls:
if self.cls_position == 0:
x = torch.cat([cls_tokens, x], dim=1)
elif self.cls_position == 1:
x = torch.cat([x, cls_tokens], dim=1)
else:
raise ValueError(
"Invalid cls_position value. It should be either 0 or 1."
)

return x
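A minimal usage sketch for the new embedding layer, assuming the package layout in this diff; the feature names, dimensions, and category counts are made up for illustration. Note that forward expects lists of per-feature tensors.

import torch
from mambular.arch_utils.embedding_layer import EmbeddingLayer

# Hypothetical metadata: two 1-dimensional numerical features and one categorical
# feature with 5 categories.
num_feature_info = {"age": 1, "income": 1}
cat_feature_info = {"city": 5}

layer = EmbeddingLayer(
    num_feature_info,
    cat_feature_info,
    d_model=32,
    use_cls=True,      # prepend a learnable class token (cls_position=0)
)

batch_size = 16
num_features = [torch.randn(batch_size, 1) for _ in num_feature_info]
cat_features = [torch.randint(0, 5, (batch_size, 1)) for _ in cat_feature_info]

x = layer(num_features=num_features, cat_features=cat_features)
# x: (batch_size, 1 + num_categorical + num_numerical, d_model) == (16, 4, 32)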
38 changes: 38 additions & 0 deletions mambular/arch_utils/learnable_ple.py
@@ -0,0 +1,38 @@
import torch
import torch.nn as nn


class PeriodicLinearEncodingLayer(nn.Module):
def __init__(self, bins=10, learn_bins=True):
super(PeriodicLinearEncodingLayer, self).__init__()
self.bins = bins
self.learn_bins = learn_bins

if self.learn_bins:
# Learnable bin boundaries
self.bin_boundaries = nn.Parameter(torch.linspace(0, 1, self.bins + 1))
        else:
            # Fixed bin boundaries; registered as a buffer so they move with the module's device
            self.register_buffer("bin_boundaries", torch.linspace(-1, 1, self.bins + 1))

def forward(self, x):
if self.learn_bins:
# Ensure bin boundaries are sorted
sorted_bins = torch.sort(self.bin_boundaries)[0]
else:
sorted_bins = self.bin_boundaries

        # x is expected to have shape (batch_size, 1); the encoding has one column per bin
        z = torch.zeros(x.size(0), self.bins, device=x.device)

        # Piecewise linear encoding: bins entirely below x are set to 1, bins entirely
        # above x are set to 0, and the bin containing x receives x's fractional
        # position within that bin.
        for t in range(1, self.bins + 1):
b_t_1 = sorted_bins[t - 1]
b_t = sorted_bins[t]
mask1 = x < b_t_1
mask2 = x >= b_t
mask3 = (x >= b_t_1) & (x < b_t)

z[mask1.squeeze(), t - 1] = 0
z[mask2.squeeze(), t - 1] = 1
z[mask3.squeeze(), t - 1] = (x[mask3] - b_t_1) / (b_t - b_t_1)

return z
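A toy example of the encoding layer above; the input values are assumptions for illustration and are expected to lie within the bin range.

import torch
from mambular.arch_utils.learnable_ple import PeriodicLinearEncodingLayer

ple = PeriodicLinearEncodingLayer(bins=4, learn_bins=True)

# One numerical feature per row, pre-scaled to the [0, 1] range of the learnable bins.
x = torch.tensor([[0.10], [0.55], [0.90]])
z = ple(x)
# z has shape (3, 4): per row, 1.0 for bins below the value, 0.0 for bins above it,
# and the fractional position within the bin that contains the value.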
