Merge pull request #106 from basf/develop
Version 0.2.1
AnFreTh authored Aug 13, 2024
2 parents 57ca8fc + a383411 commit dad9a12
Showing 41 changed files with 2,801 additions and 862 deletions.
22 changes: 0 additions & 22 deletions .github/workflows/draft-pdf.yml

This file was deleted.

615 changes: 352 additions & 263 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mambular/__version__.py
@@ -1,4 +1,4 @@
"""Version information."""

# The following line *must* be the last in the module, exactly as formatted:
-__version__ = "0.1.7"
+__version__ = "0.2.1"
102 changes: 102 additions & 0 deletions mambular/arch_utils/attention_net_arch_utils.py
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn


class Reshape(nn.Module):
    """Expands a flat input of shape (batch_size, dim) into a sequence of j
    tokens of shape (batch_size, j, dim), using a linear layer, an embedding
    lookup, or a 1x1 convolution, depending on ``method``."""

    def __init__(self, j, dim, method="linear"):
super(Reshape, self).__init__()
self.j = j
self.dim = dim
self.method = method

if self.method == "linear":
# Use nn.Linear approach
self.layer = nn.Linear(dim, j * dim)
elif self.method == "embedding":
# Use nn.Embedding approach
self.layer = nn.Embedding(dim, j * dim)
elif self.method == "conv1d":
# Use nn.Conv1d approach
self.layer = nn.Conv1d(in_channels=dim, out_channels=j * dim, kernel_size=1)
else:
raise ValueError(f"Unsupported method '{method}' for reshaping.")

def forward(self, x):
batch_size = x.shape[0]

if self.method == "linear" or self.method == "embedding":
x_reshaped = self.layer(x) # shape: (batch_size, j * dim)
x_reshaped = x_reshaped.view(
batch_size, self.j, self.dim
) # shape: (batch_size, j, dim)
elif self.method == "conv1d":
# For Conv1d, add dummy dimension and reshape
x = x.unsqueeze(-1) # Add dummy dimension for convolution
x_reshaped = self.layer(x) # shape: (batch_size, j * dim, 1)
x_reshaped = x_reshaped.squeeze(-1) # Remove dummy dimension
x_reshaped = x_reshaped.view(
batch_size, self.j, self.dim
) # shape: (batch_size, j, dim)

return x_reshaped


class AttentionNetBlock(nn.Module):
    """Reshapes a flat feature vector into a token sequence, encodes it with a
    Transformer encoder, adds a residual connection, sums over the token
    dimension, and projects the result to ``output_dim``."""

def __init__(
self,
channels,
in_channels,
d_model,
n_heads,
n_layers,
dim_feedforward,
transformer_activation,
output_dim,
attn_dropout,
layer_norm_eps,
norm_first,
bias,
activation,
embedding_activation,
norm_f,
method,
):
super(AttentionNetBlock, self).__init__()

self.reshape = Reshape(channels, in_channels, method)

encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=n_heads,
batch_first=True,
dim_feedforward=dim_feedforward,
dropout=attn_dropout,
activation=transformer_activation,
layer_norm_eps=layer_norm_eps,
norm_first=norm_first,
bias=bias,
)

self.encoder = nn.TransformerEncoder(
encoder_layer,
num_layers=n_layers,
norm=norm_f,
)

self.linear = nn.Linear(d_model, output_dim)
self.activation = activation
self.embedding_activation = embedding_activation

    def forward(self, x):
        z = self.reshape(x)  # (batch_size, j, dim)
        x = self.embedding_activation(z)
        x = self.encoder(x)
        x = z + x  # residual connection around the encoder
        x = torch.sum(x, dim=1)  # pool over the token dimension
        x = self.linear(x)
        x = self.activation(x)
        return x
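
Below is a minimal usage sketch for the new AttentionNetBlock, not part of this commit; all hyperparameter values are illustrative assumptions, and in_channels is chosen equal to d_model so that the reshaped tokens match the encoder width.

# Usage sketch (illustrative, not part of this diff)
import torch
import torch.nn as nn

block = AttentionNetBlock(
    channels=4,                 # j: number of tokens produced by Reshape
    in_channels=32,             # per-token dimension; assumed equal to d_model
    d_model=32,
    n_heads=4,
    n_layers=2,
    dim_feedforward=64,
    transformer_activation="gelu",
    output_dim=1,
    attn_dropout=0.1,
    layer_norm_eps=1e-5,
    norm_first=False,
    bias=True,
    activation=nn.Identity(),
    embedding_activation=nn.ReLU(),
    norm_f=nn.LayerNorm(32),
    method="linear",
)

x = torch.randn(8, 32)          # (batch_size, in_channels)
out = block(x)                  # (batch_size, output_dim)
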
97 changes: 97 additions & 0 deletions mambular/arch_utils/attention_utils.py
@@ -0,0 +1,97 @@
import torch.nn as nn
import torch
from rotary_embedding_torch import RotaryEmbedding
from einops import rearrange
import torch.nn.functional as F
import numpy as np


class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim=-1)
return x * F.gelu(gates)


def FeedForward(dim, mult=4, dropout=0.0):
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim),
)


class Attention(nn.Module):
    """Multi-head self-attention that optionally applies rotary position
    embeddings to queries and keys; returns both the attended output and the
    post-softmax attention weights."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head**-0.5
self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
self.dropout = nn.Dropout(dropout)
self.rotary = rotary
        # the rotary embedding dimension is set to half of the model dimension
        dim = np.int64(dim / 2)
        self.rotary_embedding = RotaryEmbedding(dim=dim)

def forward(self, x):
h = self.heads
x = self.norm(x)
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
if self.rotary:
q = self.rotary_embedding.rotate_queries_or_keys(q)
k = self.rotary_embedding.rotate_queries_or_keys(k)
q = q * self.scale

sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)

attn = sim.softmax(dim=-1)
dropped_attn = self.dropout(attn)

out = torch.einsum("b h i j, b h j d -> b h i d", dropped_attn, v)
out = rearrange(out, "b h n d -> b n (h d)", h=h)
out = self.to_out(out)

return out, attn


class Transformer(nn.Module):
    """Stack of attention and GEGLU feed-forward blocks with residual
    connections; can optionally return the per-layer attention maps."""

    def __init__(
self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary=False
):
super().__init__()
self.layers = nn.ModuleList([])

for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
Attention(
dim,
heads=heads,
dim_head=dim_head,
dropout=attn_dropout,
rotary=rotary,
),
FeedForward(dim, dropout=ff_dropout),
]
)
)

def forward(self, x, return_attn=False):
post_softmax_attns = []

for attn, ff in self.layers:
attn_out, post_softmax_attn = attn(x)
post_softmax_attns.append(post_softmax_attn)

x = attn_out + x
x = ff(x) + x

if not return_attn:
return x

return x, torch.stack(post_softmax_attns)
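
A minimal usage sketch for the Transformer defined above, not part of this commit; the dimensions are illustrative assumptions, with dim_head chosen equal to dim / 2 so the rotary embedding fits each attention head.

# Usage sketch (illustrative, not part of this diff)
import torch

model = Transformer(
    dim=32,            # token embedding dimension
    depth=2,           # number of attention / feed-forward blocks
    heads=4,
    dim_head=16,       # equals dim / 2, matching the rotary embedding size
    attn_dropout=0.1,
    ff_dropout=0.1,
    rotary=True,
)

tokens = torch.randn(8, 10, 32)               # (batch, sequence, dim)
out = model(tokens)                           # (batch, sequence, dim)
out, attn = model(tokens, return_attn=True)   # attn: (depth, batch, heads, seq, seq)
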
163 changes: 163 additions & 0 deletions mambular/arch_utils/embedding_layer.py
@@ -0,0 +1,163 @@
import torch
import torch.nn as nn


class EmbeddingLayer(nn.Module):
def __init__(
self,
num_feature_info,
cat_feature_info,
d_model,
embedding_activation=nn.Identity(),
layer_norm_after_embedding=False,
use_cls=False,
cls_position=0,
cat_encoding="int",
):
"""
        Embedding layer that handles numerical and categorical embeddings.

        Parameters
        ----------
        num_feature_info : dict
            Dictionary where keys are numerical feature names and values are their respective input dimensions.
        cat_feature_info : dict
            Dictionary where keys are categorical feature names and values are the number of categories for each feature.
        d_model : int
            Dimensionality of the embeddings.
        embedding_activation : nn.Module, optional
            Activation function to apply after embedding. Default is `nn.Identity()`.
        layer_norm_after_embedding : bool, optional
            If True, applies layer normalization after embeddings. Default is `False`.
        use_cls : bool, optional
            If True, includes a class token in the embeddings. Default is `False`.
        cls_position : int, optional
            Position to place the class token, either at the start (0) or end (1) of the sequence. Default is `0`.
        cat_encoding : str, optional
            Encoding used for categorical features, either "int" (embedding lookup) or "one-hot". Default is `"int"`.

        Methods
        -------
        forward(num_features=None, cat_features=None)
            Defines the forward pass of the model.
        """
super(EmbeddingLayer, self).__init__()

self.d_model = d_model
self.embedding_activation = embedding_activation
self.layer_norm_after_embedding = layer_norm_after_embedding
self.use_cls = use_cls
self.cls_position = cls_position

self.num_embeddings = nn.ModuleList(
[
nn.Sequential(
nn.Linear(input_shape, d_model, bias=False),
self.embedding_activation,
)
for feature_name, input_shape in num_feature_info.items()
]
)

self.cat_embeddings = nn.ModuleList()
for feature_name, num_categories in cat_feature_info.items():
if cat_encoding == "int":
self.cat_embeddings.append(
nn.Sequential(
nn.Embedding(num_categories + 1, d_model),
self.embedding_activation,
)
)
elif cat_encoding == "one-hot":
self.cat_embeddings.append(
nn.Sequential(
OneHotEncoding(num_categories),
nn.Linear(num_categories, d_model, bias=False),
self.embedding_activation,
)
)

if self.use_cls:
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
if layer_norm_after_embedding:
self.embedding_norm = nn.LayerNorm(d_model)

self.seq_len = len(self.num_embeddings) + len(self.cat_embeddings)

def forward(self, num_features=None, cat_features=None):
"""
        Defines the forward pass of the model.

        Parameters
        ----------
        num_features : list of Tensors, optional
            List of tensors containing the numerical features, one tensor per feature.
        cat_features : list of Tensors, optional
            List of tensors containing the categorical features, one tensor per feature.

        Returns
        -------
        Tensor
            The output embeddings of the model.

        Raises
        ------
        ValueError
            If no features are provided to the model.
        """
        if self.use_cls:
            batch_size = (
                cat_features[0].size(0)
                if cat_features  # also guards against cat_features being None
                else num_features[0].size(0)
            )
cls_tokens = self.cls_token.expand(batch_size, -1, -1)

if self.cat_embeddings and cat_features is not None:
cat_embeddings = [
emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
]
cat_embeddings = torch.stack(cat_embeddings, dim=1)
cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
if self.layer_norm_after_embedding:
cat_embeddings = self.embedding_norm(cat_embeddings)
else:
cat_embeddings = None

if self.num_embeddings and num_features is not None:
num_embeddings = [
emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)
]
num_embeddings = torch.stack(num_embeddings, dim=1)
if self.layer_norm_after_embedding:
num_embeddings = self.embedding_norm(num_embeddings)
else:
num_embeddings = None

if cat_embeddings is not None and num_embeddings is not None:
x = torch.cat([cat_embeddings, num_embeddings], dim=1)
elif cat_embeddings is not None:
x = cat_embeddings
elif num_embeddings is not None:
x = num_embeddings
else:
raise ValueError("No features provided to the model.")

if self.use_cls:
if self.cls_position == 0:
x = torch.cat([cls_tokens, x], dim=1)
elif self.cls_position == 1:
x = torch.cat([x, cls_tokens], dim=1)
else:
raise ValueError(
"Invalid cls_position value. It should be either 0 or 1."
)

return x


class OneHotEncoding(nn.Module):
    """One-hot encodes integer category indices into float vectors."""

    def __init__(self, num_categories):
        super(OneHotEncoding, self).__init__()
self.num_categories = num_categories

def forward(self, x):
return torch.nn.functional.one_hot(x, num_classes=self.num_categories).float()
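
A minimal usage sketch for the new EmbeddingLayer, not part of this commit; the feature names, category counts, and shapes below are illustrative assumptions.

# Usage sketch (illustrative, not part of this diff)
import torch
import torch.nn as nn

num_feature_info = {"age": 1, "income": 1}    # numerical feature -> input dimension
cat_feature_info = {"city": 10, "gender": 2}  # categorical feature -> number of categories

layer = EmbeddingLayer(
    num_feature_info,
    cat_feature_info,
    d_model=64,
    embedding_activation=nn.ReLU(),
    layer_norm_after_embedding=True,
    use_cls=True,
    cls_position=0,
    cat_encoding="int",
)

batch_size = 8
num_features = [torch.randn(batch_size, 1) for _ in num_feature_info]
cat_features = [
    torch.randint(0, 10, (batch_size, 1)),
    torch.randint(0, 2, (batch_size, 1)),
]

x = layer(num_features=num_features, cat_features=cat_features)
# x: (batch_size, 1 + n_cat + n_num, d_model) = (8, 5, 64), CLS token first
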
The remaining changed files in this commit are not rendered here.
