Merge pull request #90 from basf/layer_improvement
Layer improvement
Showing 28 changed files with 1,093 additions and 261 deletions.
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn


class Reshape(nn.Module):
    def __init__(self, j, dim, method="linear"):
        super(Reshape, self).__init__()
        self.j = j
        self.dim = dim
        self.method = method

        if self.method == "linear":
            # Use nn.Linear approach
            self.layer = nn.Linear(dim, j * dim)
        elif self.method == "embedding":
            # Use nn.Embedding approach
            self.layer = nn.Embedding(dim, j * dim)
        elif self.method == "conv1d":
            # Use nn.Conv1d approach
            self.layer = nn.Conv1d(in_channels=dim, out_channels=j * dim, kernel_size=1)
        else:
            raise ValueError(f"Unsupported method '{method}' for reshaping.")

    def forward(self, x):
        batch_size = x.shape[0]

        if self.method == "linear" or self.method == "embedding":
            x_reshaped = self.layer(x)  # shape: (batch_size, j * dim)
            x_reshaped = x_reshaped.view(
                batch_size, self.j, self.dim
            )  # shape: (batch_size, j, dim)
        elif self.method == "conv1d":
            # For Conv1d, add dummy dimension and reshape
            x = x.unsqueeze(-1)  # Add dummy dimension for convolution
            x_reshaped = self.layer(x)  # shape: (batch_size, j * dim, 1)
            x_reshaped = x_reshaped.squeeze(-1)  # Remove dummy dimension
            x_reshaped = x_reshaped.view(
                batch_size, self.j, self.dim
            )  # shape: (batch_size, j, dim)

        return x_reshaped


class AttentionNetBlock(nn.Module):
    def __init__(
        self,
        channels,
        in_channels,
        d_model,
        n_heads,
        n_layers,
        dim_feedforward,
        transformer_activation,
        output_dim,
        attn_dropout,
        layer_norm_eps,
        norm_first,
        bias,
        activation,
        embedding_activation,
        norm_f,
        method,
    ):
        super(AttentionNetBlock, self).__init__()

        # Expand the flat input into `channels` tokens of width `in_channels`.
        self.reshape = Reshape(channels, in_channels, method)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            batch_first=True,
            dim_feedforward=dim_feedforward,
            dropout=attn_dropout,
            activation=transformer_activation,
            layer_norm_eps=layer_norm_eps,
            norm_first=norm_first,
            bias=bias,
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers,
            norm=norm_f,
        )

        self.linear = nn.Linear(d_model, output_dim)
        self.activation = activation
        self.embedding_activation = embedding_activation

    def forward(self, x):
        z = self.reshape(x)
        x = self.embedding_activation(z)
        x = self.encoder(x)
        x = z + x  # residual connection around the encoder
        x = torch.sum(x, dim=1)  # pool over the token dimension
        x = self.linear(x)
        x = self.activation(x)
        return x
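
For context, here is a minimal usage sketch (not part of the diff). All hyperparameter values below are illustrative assumptions; `d_model` is set equal to `in_channels` because `Reshape` turns a `(batch, in_channels)` input into `channels` tokens of width `in_channels`, which then feed the transformer encoder.

import torch
import torch.nn as nn

# Hypothetical hyperparameters, chosen only to illustrate shapes.
block = AttentionNetBlock(
    channels=8,
    in_channels=64,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dim_feedforward=128,
    transformer_activation="gelu",
    output_dim=1,
    attn_dropout=0.1,
    layer_norm_eps=1e-5,
    norm_first=False,
    bias=True,
    activation=nn.Identity(),
    embedding_activation=nn.ReLU(),
    norm_f=None,
    method="linear",
)

x = torch.randn(32, 64)  # (batch_size, in_channels)
out = block(x)           # (batch_size, output_dim)
print(out.shape)         # torch.Size([32, 1])
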
@@ -0,0 +1,97 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from rotary_embedding_torch import RotaryEmbedding


class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)


def FeedForward(dim, mult=4, dropout=0.0):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim),
    )
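
As an aside (not part of the diff): GEGLU splits its input in half along the last axis and gates one half with GELU, which is why FeedForward first expands to dim * mult * 2 and then projects back from dim * mult. A quick shape check, assuming nothing beyond the definitions above:

import torch

ff = FeedForward(dim=64, mult=4, dropout=0.0)
x = torch.randn(2, 10, 64)  # (batch, tokens, dim)
print(ff(x).shape)          # torch.Size([2, 10, 64])
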


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head**-0.5
        self.norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.rotary = rotary
        # Rotary embeddings are built for half of the model dimension.
        self.rotary_embedding = RotaryEmbedding(dim=dim // 2)

    def forward(self, x):
        h = self.heads
        x = self.norm(x)
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
        if self.rotary:
            q = self.rotary_embedding.rotate_queries_or_keys(q)
            k = self.rotary_embedding.rotate_queries_or_keys(k)
        q = q * self.scale

        sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)

        attn = sim.softmax(dim=-1)
        dropped_attn = self.dropout(attn)

        out = torch.einsum("b h i j, b h j d -> b h i d", dropped_attn, v)
        out = rearrange(out, "b h n d -> b n (h d)", h=h)
        out = self.to_out(out)

        return out, attn


class Transformer(nn.Module):
    def __init__(
        self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary=False
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Attention(
                            dim,
                            heads=heads,
                            dim_head=dim_head,
                            dropout=attn_dropout,
                            rotary=rotary,
                        ),
                        FeedForward(dim, dropout=ff_dropout),
                    ]
                )
            )

    def forward(self, x, return_attn=False):
        post_softmax_attns = []

        for attn, ff in self.layers:
            attn_out, post_softmax_attn = attn(x)
            post_softmax_attns.append(post_softmax_attn)

            x = attn_out + x
            x = ff(x) + x

        if not return_attn:
            return x

        return x, torch.stack(post_softmax_attns)
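
For context, a minimal usage sketch (not part of the diff); the shapes are illustrative assumptions. Note that the rotary embedding above is sized from the model dimension (dim // 2), so with rotary=True the per-head dimension should be at least that large; the sketch leaves rotary off.

import torch

# Illustrative sizes: 16 tokens of width 64, 2 transformer blocks.
model = Transformer(
    dim=64, depth=2, heads=8, dim_head=16, attn_dropout=0.1, ff_dropout=0.1, rotary=False
)
x = torch.randn(4, 16, 64)     # (batch, tokens, dim)
out, attns = model(x, return_attn=True)
print(out.shape, attns.shape)  # torch.Size([4, 16, 64]) torch.Size([2, 4, 8, 16, 16])
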
@@ -0,0 +1,145 @@
import torch
import torch.nn as nn


class EmbeddingLayer(nn.Module):
    def __init__(
        self,
        num_feature_info,
        cat_feature_info,
        d_model,
        embedding_activation=nn.Identity(),
        layer_norm_after_embedding=False,
        use_cls=False,
        cls_position=0,
    ):
        """
        Embedding layer that handles numerical and categorical embeddings.

        Parameters
        ----------
        num_feature_info : dict
            Dictionary where keys are numerical feature names and values are their respective input dimensions.
        cat_feature_info : dict
            Dictionary where keys are categorical feature names and values are the number of categories for each feature.
        d_model : int
            Dimensionality of the embeddings.
        embedding_activation : nn.Module, optional
            Activation function to apply after embedding. Default is `nn.Identity()`.
        layer_norm_after_embedding : bool, optional
            If True, applies layer normalization after embeddings. Default is `False`.
        use_cls : bool, optional
            If True, includes a class token in the embeddings. Default is `False`.
        cls_position : int, optional
            Position to place the class token, either at the start (0) or end (1) of the sequence. Default is `0`.

        Methods
        -------
        forward(num_features=None, cat_features=None)
            Defines the forward pass of the model.
        """
        super(EmbeddingLayer, self).__init__()

        self.d_model = d_model
        self.embedding_activation = embedding_activation
        self.layer_norm_after_embedding = layer_norm_after_embedding
        self.use_cls = use_cls
        self.cls_position = cls_position

        self.num_embeddings = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(input_shape, d_model, bias=False),
                    self.embedding_activation,
                )
                for feature_name, input_shape in num_feature_info.items()
            ]
        )

        self.cat_embeddings = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Embedding(num_categories + 1, d_model),
                    self.embedding_activation,
                )
                for feature_name, num_categories in cat_feature_info.items()
            ]
        )

        if self.use_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        if layer_norm_after_embedding:
            self.embedding_norm = nn.LayerNorm(d_model)

        self.seq_len = len(self.num_embeddings) + len(self.cat_embeddings)

    def forward(self, num_features=None, cat_features=None):
        """
        Defines the forward pass of the model.

        Parameters
        ----------
        num_features : list of Tensors, optional
            List of tensors containing the numerical features, one per numerical feature.
        cat_features : list of Tensors, optional
            List of tensors containing the categorical features, one per categorical feature.

        Returns
        -------
        Tensor
            The output embeddings of the model.

        Raises
        ------
        ValueError
            If no features are provided to the model.
        """
        if self.use_cls:
            batch_size = (
                cat_features[0].size(0)
                if cat_features != []
                else num_features[0].size(0)
            )
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        if self.cat_embeddings and cat_features is not None:
            cat_embeddings = [
                emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
            ]
            cat_embeddings = torch.stack(cat_embeddings, dim=1)
            cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
            if self.layer_norm_after_embedding:
                cat_embeddings = self.embedding_norm(cat_embeddings)
        else:
            cat_embeddings = None

        if self.num_embeddings and num_features is not None:
            num_embeddings = [
                emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)
            ]
            num_embeddings = torch.stack(num_embeddings, dim=1)
            if self.layer_norm_after_embedding:
                num_embeddings = self.embedding_norm(num_embeddings)
        else:
            num_embeddings = None

        if cat_embeddings is not None and num_embeddings is not None:
            x = torch.cat([cat_embeddings, num_embeddings], dim=1)
        elif cat_embeddings is not None:
            x = cat_embeddings
        elif num_embeddings is not None:
            x = num_embeddings
        else:
            raise ValueError("No features provided to the model.")

        if self.use_cls:
            if self.cls_position == 0:
                x = torch.cat([cls_tokens, x], dim=1)
            elif self.cls_position == 1:
                x = torch.cat([x, cls_tokens], dim=1)
            else:
                raise ValueError(
                    "Invalid cls_position value. It should be either 0 or 1."
                )

        return x
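
For context, a minimal usage sketch (not part of the diff); the feature names, dimensions, and category counts below are hypothetical.

import torch
import torch.nn as nn

# Hypothetical metadata: two 1-dimensional numerical features and one
# categorical feature with 5 categories.
num_feature_info = {"age": 1, "income": 1}
cat_feature_info = {"city": 5}

emb = EmbeddingLayer(
    num_feature_info,
    cat_feature_info,
    d_model=32,
    embedding_activation=nn.ReLU(),
    use_cls=True,
    cls_position=0,
)

batch_size = 16
num_features = [torch.randn(batch_size, 1) for _ in num_feature_info]
cat_features = [torch.randint(0, 5, (batch_size, 1)) for _ in cat_feature_info]

out = emb(num_features=num_features, cat_features=cat_features)
print(out.shape)  # torch.Size([16, 4, 32]): CLS token, 1 categorical, 2 numerical
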
@@ -0,0 +1,38 @@
import torch
import torch.nn as nn


class PeriodicLinearEncodingLayer(nn.Module):
    def __init__(self, bins=10, learn_bins=True):
        super(PeriodicLinearEncodingLayer, self).__init__()
        self.bins = bins
        self.learn_bins = learn_bins

        if self.learn_bins:
            # Learnable bin boundaries
            self.bin_boundaries = nn.Parameter(torch.linspace(0, 1, self.bins + 1))
        else:
            self.bin_boundaries = torch.linspace(-1, 1, self.bins + 1)

    def forward(self, x):
        if self.learn_bins:
            # Ensure bin boundaries are sorted
            sorted_bins = torch.sort(self.bin_boundaries)[0]
        else:
            sorted_bins = self.bin_boundaries

        # Initialize z with zeros
        z = torch.zeros(x.size(0), self.bins, device=x.device)

        # Piecewise-linear encoding: 0 below a bin, 1 above it, linear inside.
        for t in range(1, self.bins + 1):
            b_t_1 = sorted_bins[t - 1]
            b_t = sorted_bins[t]
            mask1 = x < b_t_1
            mask2 = x >= b_t
            mask3 = (x >= b_t_1) & (x < b_t)

            z[mask1.squeeze(), t - 1] = 0
            z[mask2.squeeze(), t - 1] = 1
            z[mask3.squeeze(), t - 1] = (x[mask3] - b_t_1) / (b_t - b_t_1)

        return z
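
Finally, a minimal usage sketch (not part of the diff), assuming a single numerical feature already scaled to roughly [0, 1]:

import torch

enc = PeriodicLinearEncodingLayer(bins=10, learn_bins=True)
x = torch.rand(8, 1)  # (batch_size, 1)
z = enc(x)
print(z.shape)        # torch.Size([8, 10]); each column is 0, 1, or a within-bin interpolation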