Skip to content

Commit

Permalink
include tabularRNN
Browse files Browse the repository at this point in the history
  • Loading branch information
AnFreTh committed Aug 12, 2024
1 parent 91ab62c commit 43d2758
Show file tree
Hide file tree
Showing 4 changed files with 495 additions and 0 deletions.
153 changes: 153 additions & 0 deletions mambular/base_models/tabularnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import torch
import torch.nn as nn
from ..arch_utils.mlp_utils import MLP
from ..configs.tabularnn_config import DefaultTabulaRNNConfig
from .basemodel import BaseModel
from ..arch_utils.embedding_layer import EmbeddingLayer
from ..arch_utils.normalization_layers import (
RMSNorm,
LayerNorm,
LearnableLayerScaling,
BatchNorm,
InstanceNorm,
GroupNorm,
)


class TabulaRNN(BaseModel):
def __init__(
self,
cat_feature_info,
num_feature_info,
num_classes=1,
config: DefaultTabulaRNNConfig = DefaultTabulaRNNConfig(),
**kwargs,
):
super().__init__(**kwargs)
self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])

self.lr = self.hparams.get("lr", config.lr)
self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
self.pooling_method = self.hparams.get("pooling_method", config.pooling_method)
self.cat_feature_info = cat_feature_info
self.num_feature_info = num_feature_info

norm_layer = self.hparams.get("norm", config.norm)
if norm_layer == "RMSNorm":
self.norm_f = RMSNorm(
self.hparams.get("dim_feedforward", config.dim_feedforward)
)
elif norm_layer == "LayerNorm":
self.norm_f = LayerNorm(
self.hparams.get("dim_feedforward", config.dim_feedforward)
)
elif norm_layer == "BatchNorm":
self.norm_f = BatchNorm(
self.hparams.get("dim_feedforward", config.dim_feedforward)
)
elif norm_layer == "InstanceNorm":
self.norm_f = InstanceNorm(
self.hparams.get("dim_feedforward", config.dim_feedforward)
)
elif norm_layer == "GroupNorm":
self.norm_f = GroupNorm(
1, self.hparams.get("dim_feedforward", config.dim_feedforward)
)
elif norm_layer == "LearnableLayerScaling":
self.norm_f = LearnableLayerScaling(
self.hparams.get("dim_feedforward", config.dim_feedforward)
)
else:
self.norm_f = None

rnn_layer = {"RNN": nn.RNN, "LSTM": nn.LSTM, "GRU": nn.GRU}[config.model_type]
self.rnn = rnn_layer(
input_size=self.hparams.get("d_model", config.d_model),
hidden_size=self.hparams.get("dim_feedforward", config.dim_feedforward),
num_layers=self.hparams.get("n_layers", config.n_layers),
bidirectional=self.hparams.get("bidirectional", config.bidirectional),
batch_first=True,
dropout=self.hparams.get("rnn_dropout", config.rnn_dropout),
bias=self.hparams.get("bias", config.bias),
nonlinearity=(
self.hparams.get("rnn_activation", config.rnn_activation)
if config.model_type == "RNN"
else None
),
)

self.embedding_layer = EmbeddingLayer(
num_feature_info=num_feature_info,
cat_feature_info=cat_feature_info,
d_model=self.hparams.get("d_model", config.d_model),
embedding_activation=self.hparams.get(
"embedding_activation", config.embedding_activation
),
layer_norm_after_embedding=self.hparams.get(
"layer_norm_after_embedding", config.layer_norm_after_embedding
),
use_cls=False,
cls_position=-1,
cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
)

head_activation = self.hparams.get("head_activation", config.head_activation)

self.tabular_head = MLP(
self.hparams.get("dim_feedforward", config.dim_feedforward),
hidden_units_list=self.hparams.get(
"head_layer_sizes", config.head_layer_sizes
),
dropout_rate=self.hparams.get("head_dropout", config.head_dropout),
use_skip_layers=self.hparams.get(
"head_skip_layers", config.head_skip_layers
),
activation_fn=head_activation,
use_batch_norm=self.hparams.get(
"head_use_batch_norm", config.head_use_batch_norm
),
n_output_units=num_classes,
)

self.linear = nn.Linear(config.d_model, config.dim_feedforward)

def forward(self, num_features, cat_features):
"""
Defines the forward pass of the model.
Parameters
----------
num_features : Tensor
Tensor containing the numerical features.
cat_features : Tensor
Tensor containing the categorical features.
Returns
-------
Tensor
The output predictions of the model.
"""

x = self.embedding_layer(num_features, cat_features)
# RNN forward pass
out, _ = self.rnn(x)
z = self.linear(torch.mean(x, dim=1))

if self.pooling_method == "avg":
x = torch.mean(out, dim=1)
elif self.pooling_method == "max":
x, _ = torch.max(out, dim=1)
elif self.pooling_method == "sum":
x = torch.sum(out, dim=1)
elif self.pooling_method == "last":
x = x[:, -1, :]
else:
raise ValueError(f"Invalid pooling method: {self.pooling_method}")
x = x + z
if self.norm_f is not None:
x = self.norm_f(x)
preds = self.tabular_head(x)

return preds
83 changes: 83 additions & 0 deletions mambular/configs/tabularnn_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from dataclasses import dataclass
import torch.nn as nn


@dataclass
class DefaultTabulaRNNConfig:
"""
Configuration class for the default TabulaRNN model with predefined hyperparameters.
Parameters
----------
lr : float, default=1e-04
Learning rate for the optimizer.
model_type : str, default="RNN"
type of model, one of "RNN", "LSTM", "GRU"
lr_patience : int, default=10
Number of epochs with no improvement after which learning rate will be reduced.
weight_decay : float, default=1e-06
Weight decay (L2 penalty) for the optimizer.
lr_factor : float, default=0.1
Factor by which the learning rate will be reduced.
d_model : int, default=64
Dimensionality of the model.
n_layers : int, default=8
Number of layers in the transformer.
norm : str, default="RMSNorm"
Normalization method to be used.
activation : callable, default=nn.SELU()
Activation function for the transformer.
embedding_activation : callable, default=nn.Identity()
Activation function for numerical embeddings.
head_layer_sizes : list, default=(128, 64, 32)
Sizes of the layers in the head of the model.
head_dropout : float, default=0.5
Dropout rate for the head layers.
head_skip_layers : bool, default=False
Whether to skip layers in the head.
head_activation : callable, default=nn.SELU()
Activation function for the head layers.
head_use_batch_norm : bool, default=False
Whether to use batch normalization in the head layers.
layer_norm_after_embedding : bool, default=False
Whether to apply layer normalization after embedding.
pooling_method : str, default="cls"
Pooling method to be used ('cls', 'avg', etc.).
norm_first : bool, default=False
Whether to apply normalization before other operations in each transformer block.
bias : bool, default=True
Whether to use bias in the linear layers.
rnn_activation : callable, default=nn.SELU()
Activation function for the transformer layers.
bidirectional : bool, default=False.
Whether to process data bidirectionally
cat_encoding : str, default="int"
Encoding method for categorical features.
"""

lr: float = 1e-04
model_type: str = "RNN"
lr_patience: int = 10
weight_decay: float = 1e-06
lr_factor: float = 0.1
d_model: int = 128
n_layers: int = 4
rnn_dropout: float = 0.2
norm: str = "RMSNorm"
activation: callable = nn.SELU()
embedding_activation: callable = nn.Identity()
head_layer_sizes: list = ()
head_dropout: float = 0.5
head_skip_layers: bool = False
head_activation: callable = nn.SELU()
head_use_batch_norm: bool = False
layer_norm_after_embedding: bool = False
pooling_method: str = "avg"
norm_first: bool = False
bias: bool = True
rnn_activation: str = "relu"
layer_norm_eps: float = 1e-05
dim_feedforward: int = 256
numerical_embedding: str = "ple"
bidirectional: bool = False
cat_encoding: str = "int"
4 changes: 4 additions & 0 deletions mambular/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)

from .mambatab import MambaTabClassifier, MambaTabRegressor, MambaTabLSS
from .tabularnn import TabulaRNNClassifier, TabulaRNNRegressor, TabulaRNNLSS


__all__ = [
Expand All @@ -40,4 +41,7 @@
"MambaTabRegressor",
"MambaTabClassifier",
"MambaTabLSS",
"TabulaRNNClassifier",
"TabulaRNNRegressor",
"TabulaRNNLSS",
]
Loading

0 comments on commit 43d2758

Please sign in to comment.