diff --git a/.github/workflows/draft-pdf.yml b/.github/workflows/draft-pdf.yml
deleted file mode 100644
index 0b6a410..0000000
--- a/.github/workflows/draft-pdf.yml
+++ /dev/null
@@ -1,22 +0,0 @@
-on:
- push:
-
-jobs:
- paper:
- runs-on: ubuntu-latest
- name: Paper Draft
- steps:
- - name: Checkout
- uses: actions/checkout@v3
-
- - name: Build draft PDF
- uses: openjournals/openjournals-draft-action@master
- with:
- journal: joss
- paper-path: paper/paper.md
-
- - name: Upload
- uses: actions/upload-artifact@v3
- with:
- name: paper
- path: paper/paper.pdf
diff --git a/README.md b/README.md
index 735e24a..5c48117 100644
--- a/README.md
+++ b/README.md
@@ -1,263 +1,352 @@
-
-
-
-
-[![PyPI](https://img.shields.io/pypi/v/mambular)](https://pypi.org/project/mambular)
-![PyPI - Downloads](https://img.shields.io/pypi/dm/mambular)
-[![docs build](https://readthedocs.org/projects/mambular/badge/?version=latest)](https://mambular.readthedocs.io/en/latest/?badge=latest)
-[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mambular.readthedocs.io/en/latest/)
-[![open issues](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/basf/mamba-tabular/issues)
-
-
-[📘Documentation](https://mambular.readthedocs.io/en/latest/index.html) |
-[🛠️Installation](https://mambular.readthedocs.io/en/latest/installation.html) |
-[Models](https://mambular.readthedocs.io/en/latest/api/models/index.html) |
-[🤔Report Issues](https://github.com/basf/mamba-tabular/issues)
-
-
-# Mambular: Tabular Deep Learning with Mamba Architectures
-
-Mambular is a Python package that brings the power of advanced deep learning architectures to tabular data, offering a suite of models for regression, classification, and distributional regression tasks. Designed with ease of use in mind, Mambular models adhere to scikit-learn's `BaseEstimator` interface, making them highly compatible with the familiar scikit-learn ecosystem. This means you can fit, predict, and evaluate using Mambular models just as you would with any traditional scikit-learn model, but with the added performance and flexibility of deep learning.
-
-## Features
-
-- **Comprehensive Model Suite**: Includes modules for regression, classification, and distributional regression, catering to a wide range of tabular data tasks.
-- **State-of-the-Art Architectures**: Leverages various advanced architectures known for their effectiveness in handling tabular data. Mambular models include powerful Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) and can include bidirectional processing as well as feature interaction layers.
-- **Seamless Integration**: Designed to work effortlessly with scikit-learn, allowing for easy inclusion in existing machine learning pipelines, cross-validation, and hyperparameter tuning workflows.
-- **Extensive Preprocessing**: Comes with a powerful preprocessing module that supports a broad array of data transformation techniques, ensuring that your data is optimally prepared for model training.
-- **Sklearn-like API**: The familiar scikit-learn `fit`, `predict`, and `predict_proba` methods mean minimal learning curve for those already accustomed to scikit-learn.
-- **PyTorch Lightning Under the Hood**: Built on top of PyTorch Lightning, Mambular models benefit from streamlined training processes, easy customization, and advanced features like distributed training and 16-bit precision.
-
-
-
-## Models
-
-| Model | Description |
-|---------------------|--------------------------------------------------------------------------------------------------|
-| `Mambular` | An advanced model using Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) specifically designed for various tabular data tasks. |
-| `FTTransformer` | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data. |
-| `MLP` | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks. |
-| `ResNet` | An adaptation of the ResNet architecture for tabular data applications. |
-| `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities. |
-
-All models are available for `regression`, `classification` and distributional regression, denoted by `LSS`.
-Hence, they are available as e.g. `MambularRegressor`, `MambularClassifier` or `MambularLSS`
-
-
-
-## Documentation
-
-You can find the Mamba-Tabular API documentation [here](https://mambular.readthedocs.io/en/latest/).
-
-## Installation
-
-Install Mambular using pip:
-```sh
-pip install mambular
-```
-
-## Preprocessing
-
-Mambular simplifies the preprocessing stage of model development with a comprehensive set of techniques to prepare your data for Mamba architectures. Our preprocessing module is designed to be both powerful and easy to use, offering a variety of options to efficiently transform your tabular data.
-
-### Data Type Detection and Transformation
-
-Mambular automatically identifies the type of each feature in your dataset and applies the most appropriate transformations for numerical and categorical variables. This includes:
-- **Ordinal Encoding**: Categorical features are seamlessly transformed into numerical values, preserving their inherent order and making them model-ready.
-- **One-Hot Encoding**: For nominal data, Mambular employs one-hot encoding to capture the presence or absence of categories without imposing ordinality.
-- **Binning**: Numerical features can be discretized into bins, a useful technique for handling continuous variables in certain modeling contexts.
-- **Decision Tree Binning**: Optionally, Mambular can use decision trees to find the optimal binning strategy for numerical features, enhancing model interpretability and performance.
-- **Normalization**: Mambular can easily handle numerical features without specifically turning them into categorical features. Standard preprocessing steps such as normalization per feature are possible.
-- **Standardization**: Similarly, standardization instead of normalization can be used to scale features based on the mean and standard deviation.
-- **PLE (Periodic Linear Encoding)**: This technique can be applied to numerical features to enhance the performance of tabular deep learning methods by encoding periodicity.
-- **Quantile Transformation**: Numerical features can be transformed to follow a uniform or normal distribution, improving model robustness to outliers.
-- **Spline Transformation**: Applies piecewise polynomial functions to numerical features, capturing nonlinear relationships more effectively.
-- **Polynomial Features**: Generates polynomial and interaction features, increasing the feature space to capture more complex relationships within the data.
-
-
-
-### Handling Missing Values
-
-Our preprocessing pipeline effectively handles missing data by using mean imputation for numerical features and mode imputation for categorical features. This ensures that your models receive complete data inputs without needing manual intervention.
-Additionally, Mambular can manage unknown categorical values during inference by incorporating classical tokens in categorical preprocessing.
-
-
-## Fit a Model
-Fitting a model in mambular is as simple as it gets. All models in mambular are sklearn BaseEstimators. Thus the `.fit` method is implemented for all of them. Additionally, this allows for using all other sklearn inherent methods such as their built in hyperparameter optimization tools.
-
-```python
-from mambular.models import MambularClassifier
-# Initialize and fit your model
-model = MambularClassifier(
- d_model=64,
- n_layers=8,
- numerical_preprocessing="ple",
- n_bins=50
-)
-
-# X can be a dataframe or something that can be easily transformed into a pd.DataFrame as a np.array
-model.fit(X, y, max_epochs=150, lr=1e-04)
-```
-
-Predictions are also easily obtained:
-```python
-# simple predictions
-preds = model.predict(X)
-
-# Predict probabilities
-preds = model.predict_proba(X)
-```
-
-
-## Distributional Regression with MambularLSS
-
-Mambular introduces an approach to distributional regression through its `MambularLSS` module, allowing users to model the full distribution of a response variable, not just its mean. This method is particularly valuable in scenarios where understanding the variability, skewness, or kurtosis of the response distribution is as crucial as predicting its central tendency. All available moedls in mambular are also available as distributional models.
-
-### Key Features of MambularLSS:
-
-- **Full Distribution Modeling**: Unlike traditional regression models that predict a single value (e.g., the mean), `MambularLSS` models the entire distribution of the response variable. This allows for more informative predictions, including quantiles, variance, and higher moments.
-- **Customizable Distribution Types**: `MambularLSS` supports a variety of distribution families (e.g., Gaussian, Poisson, Binomial), making it adaptable to different types of response variables, from continuous to count data.
-- **Location, Scale, Shape Parameters**: The model predicts parameters corresponding to the location, scale, and shape of the distribution, offering a nuanced understanding of the data's underlying distributional characteristics.
-- **Enhanced Predictive Uncertainty**: By modeling the full distribution, `MambularLSS` provides richer information on predictive uncertainty, enabling more robust decision-making processes in uncertain environments.
-
-
-
-### Available Distribution Classes:
-
-`MambularLSS` offers a wide range of distribution classes to cater to various statistical modeling needs. The available distribution classes include:
-
-- `normal`: Normal Distribution for modeling continuous data with a symmetric distribution around the mean.
-- `poisson`: Poisson Distribution for modeling count data that for instance represent the number of events occurring within a fixed interval.
-- `gamma`: Gamma Distribution for modeling continuous data that is skewed and bounded at zero, often used for waiting times.
-- `beta`: Beta Distribution for modeling data that is bounded between 0 and 1, useful for proportions and percentages.
-- `dirichlet`: Dirichlet Distribution for modeling multivariate data where individual components are correlated, and the sum is constrained to 1.
-- `studentt`: Student's T-Distribution for modeling data with heavier tails than the normal distribution, useful when the sample size is small.
-- `negativebinom`: Negative Binomial Distribution for modeling count data with over-dispersion relative to the Poisson distribution.
-- `inversegamma`: Inverse Gamma Distribution, often used as a prior distribution in Bayesian inference for scale parameters.
-- `categorical`: Categorical Distribution for modeling categorical data with more than two categories.
-
-These distribution classes allow `MambularLSS` to flexibly model a wide variety of data types and distributions, providing users with the tools needed to capture the full complexity of their data.
-
-
-### Getting Started with MambularLSS:
-
-To integrate distributional regression into your workflow with `MambularLSS`, start by initializing the model with your desired configuration, similar to other Mambular models:
-
-```python
-from mambular.models import MambularLSS
-
-# Initialize the MambularLSS model
-model = MambularLSS(
- dropout=0.2,
- d_model=64,
- n_layers=8,
-
-)
-
-# Fit the model to your data
-model.fit(
- X,
- y,
- max_epochs=150,
- lr=1e-04,
- patience=10,
- family="normal" # define your distribution
- )
-
-```
-
-
-### Implement Your Own Model
-
-Mambular allows users to easily integrate their custom models into the existing logic. This process is designed to be straightforward, making it simple to create a PyTorch model and define its forward pass. Instead of inheriting from `nn.Module`, you inherit from Mambular's `BaseModel`. Each Mambular model takes three main arguments: the number of classes (e.g., 1 for regression or 2 for binary classification), `cat_feature_info`, and `num_feature_info` for categorical and numerical feature information, respectively. Additionally, you can provide a config argument, which can either be a custom configuration or one of the provided default configs.
-
-One of the key advantages of using Mambular is that the inputs to the forward passes are lists of tensors. While this might be unconventional, it is highly beneficial for models that treat different data types differently. For example, the TabTransformer model leverages this feature to handle categorical and numerical data separately, applying different transformations and processing steps to each type of data.
-
-Here's how you can implement a custom model with Mambular:
-
-
-1. First, define your config:
-The configuration class allows you to specify hyperparameters and other settings for your model. This can be done using a simple dataclass.
-
-```python
-from dataclasses import dataclass
-
-@dataclass
-class MyConfig:
- lr: float = 1e-04
- lr_patience: int = 10
- weight_decay: float = 1e-06
- lr_factor: float = 0.1
-```
-
-2. Second, define your model:
-Define your custom model just as you would for an `nn.Module`. The main difference is that you will inherit from `BaseModel` and use the provided feature information to construct your layers. To integrate your model into the existing API, you only need to define the architecture and the forward pass.
-
-```python
-from mambular.base_models import BaseModel
-import torch
-import torch.nn
-
-class MyCustomModel(BaseModel):
- def __init__(
- self,
- cat_feature_info,
- num_feature_info,
- num_classes: int = 1,
- config=None,
- **kwargs,
- ):
- super().__init__(**kwargs)
- self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
-
- input_dim = 0
- for feature_name, input_shape in num_feature_info.items():
- input_dim += input_shape
- for feature_name, input_shape in cat_feature_info.items():
- input_dim += 1
-
- self.linear = nn.Linear(input_dim, num_classes)
-
- def forward(self, num_features, cat_features):
- x = num_features + cat_features
- x = torch.cat(x, dim=1)
-
- # Pass through linear layer
- output = self.linear(x)
- return output
-```
-
-3. Leverage the Mambular API:
-You can build a regression, classification or distributional regression model that can leverage all of mambulars built-in methods, by using the following:
-
-```python
-from mambular.models import SklearnBaseRegressor
-
-class MyRegressor(SklearnBaseRegressor):
- def __init__(self, **kwargs):
- super().__init__(model=MyCustomModel, config=MyConfig, **kwargs)
-```
-
-4. Train and evaluate your model:
-You can now fit, evaluate, and predict with your custom model just like with any other Mambular model. For classification or distributional regression, inherit from `SklearnBaseClassifier` or `SklearnBaseLSS` respectively.
-
-```python
-regressor = MyRegressor(numerical_preprocessing="ple")
-regressor.fit(X_train, y_train, max_epochs=50)
-```
-
-## Citation
-
-If you find this project useful in your research, please consider cite:
-```BibTeX
-@misc{2024,
- title={Mambular: Tabular Deep Learning with Mamba Architectures},
- author={Anton Frederik Thielmann, Manish Kumar, Christoph Weisser, Benjamin Saefken, Soheila Samiee},
- howpublished = {\url{https://github.com/basf/mamba-tabular}},
- year={2024}
-}
-```
-
-## License
-
-The entire codebase is under MIT license.
+
+
+
+
+[![PyPI](https://img.shields.io/pypi/v/mambular)](https://pypi.org/project/mambular)
+![PyPI - Downloads](https://img.shields.io/pypi/dm/mambular)
+[![docs build](https://readthedocs.org/projects/mambular/badge/?version=latest)](https://mambular.readthedocs.io/en/latest/?badge=latest)
+[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mambular.readthedocs.io/en/latest/)
+[![open issues](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/basf/mamba-tabular/issues)
+
+
+[📘Documentation](https://mambular.readthedocs.io/en/latest/index.html) |
+[🛠️Installation](https://mambular.readthedocs.io/en/latest/installation.html) |
+[Models](https://mambular.readthedocs.io/en/latest/api/models/index.html) |
+[🤔Report Issues](https://github.com/basf/mamba-tabular/issues)
+
+
+# Mambular: Tabular Deep Learning with Mamba Architectures
+
+You can find our paper `Mambular: A Sequential Model for Tabular Deep Learning` [here](https://arxiv.org/abs/2408.06291).
+
+# Table of Contents
+- [Mambular: Tabular Deep Learning with Mamba Architectures](#mambular-tabular-deep-learning-with-mamba-architectures)
+- [Table of Contents](#table-of-contents)
+ - [Introduction](#introduction)
+ - [Features](#features)
+ - [Models](#models)
+ - [Results](#results)
+ - [Documentation](#documentation)
+ - [Installation](#installation)
+- [Usage](#usage)
+ - [Preprocessing](#preprocessing)
+ - [Data Type Detection and Transformation](#data-type-detection-and-transformation)
+ - [Handling Missing Values](#handling-missing-values)
+ - [Fit a Model](#fit-a-model)
+ - [Distributional Regression with MambularLSS](#distributional-regression-with-mambularlss)
+ - [Key Features of MambularLSS:](#key-features-of-mambularlss)
+ - [Available Distribution Classes:](#available-distribution-classes)
+ - [Getting Started with MambularLSS:](#getting-started-with-mambularlss)
+ - [Implement Your Own Model](#implement-your-own-model)
+ - [Citation](#citation)
+ - [License](#license)
+
+
+
+## Introduction
+Mambular is a Python package that brings the power of advanced deep learning architectures to tabular data, offering a suite of models for regression, classification, and distributional regression tasks. Designed with ease of use in mind, Mambular models adhere to scikit-learn's `BaseEstimator` interface, making them highly compatible with the familiar scikit-learn ecosystem. This means you can fit, predict, and evaluate using Mambular models just as you would with any traditional scikit-learn model, but with the added performance and flexibility of deep learning.
+
+## Features
+
+- **Comprehensive Model Suite**: Includes modules for regression, classification, and distributional regression, catering to a wide range of tabular data tasks.
+- **State-of-the-Art Architectures**: Leverages various advanced architectures known for their effectiveness in handling tabular data. Mambular models include powerful Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) and can include bidirectional processing as well as feature interaction layers.
+- **Seamless Integration**: Designed to work effortlessly with scikit-learn, allowing for easy inclusion in existing machine learning pipelines, cross-validation, and hyperparameter tuning workflows.
+- **Extensive Preprocessing**: Comes with a powerful preprocessing module that supports a broad array of data transformation techniques, ensuring that your data is optimally prepared for model training.
+- **Sklearn-like API**: The familiar scikit-learn `fit`, `predict`, and `predict_proba` methods mean minimal learning curve for those already accustomed to scikit-learn.
+- **PyTorch Lightning Under the Hood**: Built on top of PyTorch Lightning, Mambular models benefit from streamlined training processes, easy customization, and advanced features like distributed training and 16-bit precision.
+
+
+## Models
+
+| Model | Description |
+| ---------------- | --------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `Mambular` | A sequential model using Mamba blocks [Gu and Dao](https://arxiv.org/pdf/2312.00752) specifically designed for various tabular data tasks. |
+| `FTTransformer` | A model leveraging transformer encoders, as introduced by [Gorishniy et al.](https://arxiv.org/abs/2106.11959), for tabular data. |
+| `MLP` | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks. |
+| `ResNet` | An adaptation of the ResNet architecture for tabular data applications. |
+| `TabTransformer` | A transformer-based model for tabular data introduced by [Huang et al.](https://arxiv.org/abs/2012.06678), enhancing feature learning capabilities. |
+| `MambaTab` | A tabular model using a Mamba-Block on a joint input representation described [here](https://arxiv.org/abs/2401.08867) . Not a sequential model. |
+| `TabulaRNN` | A Recurrent Neural Network for Tabular data. Not yet included in the benchmarks |
+
+
+All models are available for `regression`, `classification` and distributional regression, denoted by `LSS`.
+Hence, they are available as e.g. `MambularRegressor`, `MambularClassifier` or `MambularLSS`
+
+
+## Results
+Detailed results for the available methods can be found [here](https://arxiv.org/abs/2408.06291).
+Note, that these are achieved results with default hyperparameter and for our splits. Performing hyperparameter optimization could improve the performance of all models.
+
+The average rank table over all models and all datasets is given here:
+
+
+
+
+
+ Model |
+ Avg. Rank |
+
+
+ Mambular |
+ 2.083 ±1.037 |
+
+
+ FT-Transformer |
+ 2.417 ±1.256 |
+
+
+ XGBoost |
+ 3.167 ±2.577 |
+
+
+ MambaTab* |
+ 4.333 ±1.374 |
+
+
+ ResNet |
+ 4.750 ±1.639 |
+
+
+ TabTransformer |
+ 6.222 ±1.618 |
+
+
+ MLP |
+ 6.500 ±1.500 |
+
+
+ MambaTab |
+ 6.583 ±1.801 |
+
+
+ MambaTabT |
+ 7.917 ±1.187 |
+
+
+
+
+
+
+
+
+
+## Documentation
+
+You can find the Mamba-Tabular API documentation [here](https://mambular.readthedocs.io/en/latest/).
+
+## Installation
+
+Install Mambular using pip:
+```sh
+pip install mambular
+```
+
+# Usage
+## Preprocessing
+
+Mambular simplifies the preprocessing stage of model development with a comprehensive set of techniques to prepare your data for Mamba architectures. Our preprocessing module is designed to be both powerful and easy to use, offering a variety of options to efficiently transform your tabular data.
+
+### Data Type Detection and Transformation
+
+Mambular automatically identifies the type of each feature in your dataset and applies the most appropriate transformations for numerical and categorical variables. This includes:
+- **Ordinal Encoding**: Categorical features are seamlessly transformed into numerical values, preserving their inherent order and making them model-ready.
+- **One-Hot Encoding**: For nominal data, Mambular employs one-hot encoding to capture the presence or absence of categories without imposing ordinality.
+- **Binning**: Numerical features can be discretized into bins, a useful technique for handling continuous variables in certain modeling contexts.
+- **Decision Tree Binning**: Optionally, Mambular can use decision trees to find the optimal binning strategy for numerical features, enhancing model interpretability and performance.
+- **Normalization**: Mambular can easily handle numerical features without specifically turning them into categorical features. Standard preprocessing steps such as normalization per feature are possible.
+- **Standardization**: Similarly, standardization instead of normalization can be used to scale features based on the mean and standard deviation.
+- **PLE (Periodic Linear Encoding)**: This technique can be applied to numerical features to enhance the performance of tabular deep learning methods by encoding periodicity.
+- **Quantile Transformation**: Numerical features can be transformed to follow a uniform or normal distribution, improving model robustness to outliers.
+- **Spline Transformation**: Applies piecewise polynomial functions to numerical features, capturing nonlinear relationships more effectively.
+- **Polynomial Features**: Generates polynomial and interaction features, increasing the feature space to capture more complex relationships within the data.
+
+
+
+### Handling Missing Values
+
+Our preprocessing pipeline effectively handles missing data by using mean imputation for numerical features and mode imputation for categorical features. This ensures that your models receive complete data inputs without needing manual intervention.
+Additionally, Mambular can manage unknown categorical values during inference by incorporating classical tokens in categorical preprocessing.
+
+
+## Fit a Model
+Fitting a model in mambular is as simple as it gets. All models in mambular are sklearn BaseEstimators. Thus the `.fit` method is implemented for all of them. Additionally, this allows for using all other sklearn inherent methods such as their built in hyperparameter optimization tools.
+
+```python
+from mambular.models import MambularClassifier
+# Initialize and fit your model
+model = MambularClassifier(
+ d_model=64,
+ n_layers=8,
+ numerical_preprocessing="ple",
+ n_bins=50
+)
+
+# X can be a dataframe or something that can be easily transformed into a pd.DataFrame as a np.array
+model.fit(X, y, max_epochs=150, lr=1e-04)
+```
+
+Predictions are also easily obtained:
+```python
+# simple predictions
+preds = model.predict(X)
+
+# Predict probabilities
+preds = model.predict_proba(X)
+```
+
+
+## Distributional Regression with MambularLSS
+
+Mambular introduces an approach to distributional regression through its `MambularLSS` module, allowing users to model the full distribution of a response variable, not just its mean. This method is particularly valuable in scenarios where understanding the variability, skewness, or kurtosis of the response distribution is as crucial as predicting its central tendency. All available moedls in mambular are also available as distributional models.
+
+### Key Features of MambularLSS:
+
+- **Full Distribution Modeling**: Unlike traditional regression models that predict a single value (e.g., the mean), `MambularLSS` models the entire distribution of the response variable. This allows for more informative predictions, including quantiles, variance, and higher moments.
+- **Customizable Distribution Types**: `MambularLSS` supports a variety of distribution families (e.g., Gaussian, Poisson, Binomial), making it adaptable to different types of response variables, from continuous to count data.
+- **Location, Scale, Shape Parameters**: The model predicts parameters corresponding to the location, scale, and shape of the distribution, offering a nuanced understanding of the data's underlying distributional characteristics.
+- **Enhanced Predictive Uncertainty**: By modeling the full distribution, `MambularLSS` provides richer information on predictive uncertainty, enabling more robust decision-making processes in uncertain environments.
+
+
+
+### Available Distribution Classes:
+
+`MambularLSS` offers a wide range of distribution classes to cater to various statistical modeling needs. The available distribution classes include:
+
+- `normal`: Normal Distribution for modeling continuous data with a symmetric distribution around the mean.
+- `poisson`: Poisson Distribution for modeling count data that for instance represent the number of events occurring within a fixed interval.
+- `gamma`: Gamma Distribution for modeling continuous data that is skewed and bounded at zero, often used for waiting times.
+- `beta`: Beta Distribution for modeling data that is bounded between 0 and 1, useful for proportions and percentages.
+- `dirichlet`: Dirichlet Distribution for modeling multivariate data where individual components are correlated, and the sum is constrained to 1.
+- `studentt`: Student's T-Distribution for modeling data with heavier tails than the normal distribution, useful when the sample size is small.
+- `negativebinom`: Negative Binomial Distribution for modeling count data with over-dispersion relative to the Poisson distribution.
+- `inversegamma`: Inverse Gamma Distribution, often used as a prior distribution in Bayesian inference for scale parameters.
+- `categorical`: Categorical Distribution for modeling categorical data with more than two categories.
+
+These distribution classes allow `MambularLSS` to flexibly model a wide variety of data types and distributions, providing users with the tools needed to capture the full complexity of their data.
+
+
+### Getting Started with MambularLSS:
+
+To integrate distributional regression into your workflow with `MambularLSS`, start by initializing the model with your desired configuration, similar to other Mambular models:
+
+```python
+from mambular.models import MambularLSS
+
+# Initialize the MambularLSS model
+model = MambularLSS(
+ dropout=0.2,
+ d_model=64,
+ n_layers=8,
+
+)
+
+# Fit the model to your data
+model.fit(
+ X,
+ y,
+ max_epochs=150,
+ lr=1e-04,
+ patience=10,
+ family="normal" # define your distribution
+ )
+
+```
+
+
+### Implement Your Own Model
+
+Mambular allows users to easily integrate their custom models into the existing logic. This process is designed to be straightforward, making it simple to create a PyTorch model and define its forward pass. Instead of inheriting from `nn.Module`, you inherit from Mambular's `BaseModel`. Each Mambular model takes three main arguments: the number of classes (e.g., 1 for regression or 2 for binary classification), `cat_feature_info`, and `num_feature_info` for categorical and numerical feature information, respectively. Additionally, you can provide a config argument, which can either be a custom configuration or one of the provided default configs.
+
+One of the key advantages of using Mambular is that the inputs to the forward passes are lists of tensors. While this might be unconventional, it is highly beneficial for models that treat different data types differently. For example, the TabTransformer model leverages this feature to handle categorical and numerical data separately, applying different transformations and processing steps to each type of data.
+
+Here's how you can implement a custom model with Mambular:
+
+
+1. First, define your config:
+The configuration class allows you to specify hyperparameters and other settings for your model. This can be done using a simple dataclass.
+
+```python
+from dataclasses import dataclass
+
+@dataclass
+class MyConfig:
+ lr: float = 1e-04
+ lr_patience: int = 10
+ weight_decay: float = 1e-06
+ lr_factor: float = 0.1
+```
+
+2. Second, define your model:
+Define your custom model just as you would for an `nn.Module`. The main difference is that you will inherit from `BaseModel` and use the provided feature information to construct your layers. To integrate your model into the existing API, you only need to define the architecture and the forward pass.
+
+```python
+from mambular.base_models import BaseModel
+import torch
+import torch.nn
+
+class MyCustomModel(BaseModel):
+ def __init__(
+ self,
+ cat_feature_info,
+ num_feature_info,
+ num_classes: int = 1,
+ config=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+
+ input_dim = 0
+ for feature_name, input_shape in num_feature_info.items():
+ input_dim += input_shape
+ for feature_name, input_shape in cat_feature_info.items():
+ input_dim += 1
+
+ self.linear = nn.Linear(input_dim, num_classes)
+
+ def forward(self, num_features, cat_features):
+ x = num_features + cat_features
+ x = torch.cat(x, dim=1)
+
+ # Pass through linear layer
+ output = self.linear(x)
+ return output
+```
+
+3. Leverage the Mambular API:
+You can build a regression, classification or distributional regression model that can leverage all of mambulars built-in methods, by using the following:
+
+```python
+from mambular.models import SklearnBaseRegressor
+
+class MyRegressor(SklearnBaseRegressor):
+ def __init__(self, **kwargs):
+ super().__init__(model=MyCustomModel, config=MyConfig, **kwargs)
+```
+
+4. Train and evaluate your model:
+You can now fit, evaluate, and predict with your custom model just like with any other Mambular model. For classification or distributional regression, inherit from `SklearnBaseClassifier` or `SklearnBaseLSS` respectively.
+
+```python
+regressor = MyRegressor(numerical_preprocessing="ple")
+regressor.fit(X_train, y_train, max_epochs=50)
+```
+
+## Citation
+
+If you find this project useful in your research, please consider cite:
+```BibTeX
+@misc{thielmann2024mambularsequentialmodeltabular,
+ title={Mambular: A Sequential Model for Tabular Deep Learning},
+ author={Anton Frederik Thielmann and Manish Kumar and Christoph Weisser and Arik Reuter and Benjamin Säfken and Soheila Samiee},
+ year={2024},
+ eprint={2408.06291},
+ archivePrefix={arXiv},
+ primaryClass={cs.LG},
+ url={https://arxiv.org/abs/2408.06291},
+}
+```
+
+## License
+
+The entire codebase is under MIT license.
diff --git a/mambular/__version__.py b/mambular/__version__.py
index ba9f224..36509b4 100644
--- a/mambular/__version__.py
+++ b/mambular/__version__.py
@@ -1,4 +1,4 @@
"""Version information."""
# The following line *must* be the last in the module, exactly as formatted:
-__version__ = "0.1.7"
+__version__ = "0.2.1"
diff --git a/mambular/arch_utils/attention_net_arch_utils.py b/mambular/arch_utils/attention_net_arch_utils.py
new file mode 100644
index 0000000..53d9060
--- /dev/null
+++ b/mambular/arch_utils/attention_net_arch_utils.py
@@ -0,0 +1,102 @@
+import torch.nn as nn
+import torch
+
+
+import torch
+import torch.nn as nn
+
+
+class Reshape(nn.Module):
+ def __init__(self, j, dim, method="linear"):
+ super(Reshape, self).__init__()
+ self.j = j
+ self.dim = dim
+ self.method = method
+
+ if self.method == "linear":
+ # Use nn.Linear approach
+ self.layer = nn.Linear(dim, j * dim)
+ elif self.method == "embedding":
+ # Use nn.Embedding approach
+ self.layer = nn.Embedding(dim, j * dim)
+ elif self.method == "conv1d":
+ # Use nn.Conv1d approach
+ self.layer = nn.Conv1d(in_channels=dim, out_channels=j * dim, kernel_size=1)
+ else:
+ raise ValueError(f"Unsupported method '{method}' for reshaping.")
+
+ def forward(self, x):
+ batch_size = x.shape[0]
+
+ if self.method == "linear" or self.method == "embedding":
+ x_reshaped = self.layer(x) # shape: (batch_size, j * dim)
+ x_reshaped = x_reshaped.view(
+ batch_size, self.j, self.dim
+ ) # shape: (batch_size, j, dim)
+ elif self.method == "conv1d":
+ # For Conv1d, add dummy dimension and reshape
+ x = x.unsqueeze(-1) # Add dummy dimension for convolution
+ x_reshaped = self.layer(x) # shape: (batch_size, j * dim, 1)
+ x_reshaped = x_reshaped.squeeze(-1) # Remove dummy dimension
+ x_reshaped = x_reshaped.view(
+ batch_size, self.j, self.dim
+ ) # shape: (batch_size, j, dim)
+
+ return x_reshaped
+
+
+class AttentionNetBlock(nn.Module):
+ def __init__(
+ self,
+ channels,
+ in_channels,
+ d_model,
+ n_heads,
+ n_layers,
+ dim_feedforward,
+ transformer_activation,
+ output_dim,
+ attn_dropout,
+ layer_norm_eps,
+ norm_first,
+ bias,
+ activation,
+ embedding_activation,
+ norm_f,
+ method,
+ ):
+ super(AttentionNetBlock, self).__init__()
+
+ self.reshape = Reshape(channels, in_channels, method)
+
+ encoder_layer = nn.TransformerEncoderLayer(
+ d_model=d_model,
+ nhead=n_heads,
+ batch_first=True,
+ dim_feedforward=dim_feedforward,
+ dropout=attn_dropout,
+ activation=transformer_activation,
+ layer_norm_eps=layer_norm_eps,
+ norm_first=norm_first,
+ bias=bias,
+ )
+
+ self.encoder = nn.TransformerEncoder(
+ encoder_layer,
+ num_layers=n_layers,
+ norm=norm_f,
+ )
+
+ self.linear = nn.Linear(d_model, output_dim)
+ self.activation = activation
+ self.embedding_activation = embedding_activation
+
+ def forward(self, x):
+ z = self.reshape(x)
+ x = self.embedding_activation(z)
+ x = self.encoder(x)
+ x = z + x
+ x = torch.sum(x, dim=1)
+ x = self.linear(x)
+ x = self.activation(x)
+ return x
diff --git a/mambular/arch_utils/attention_utils.py b/mambular/arch_utils/attention_utils.py
new file mode 100644
index 0000000..efbe74e
--- /dev/null
+++ b/mambular/arch_utils/attention_utils.py
@@ -0,0 +1,97 @@
+import torch.nn as nn
+import torch
+from rotary_embedding_torch import RotaryEmbedding
+from einops import rearrange
+import torch.nn.functional as F
+import numpy as np
+
+
+class GEGLU(nn.Module):
+ def forward(self, x):
+ x, gates = x.chunk(2, dim=-1)
+ return x * F.gelu(gates)
+
+
+def FeedForward(dim, mult=4, dropout=0.0):
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, dim * mult * 2),
+ GEGLU(),
+ nn.Dropout(dropout),
+ nn.Linear(dim * mult, dim),
+ )
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
+ super().__init__()
+ inner_dim = dim_head * heads
+ self.heads = heads
+ self.scale = dim_head**-0.5
+ self.norm = nn.LayerNorm(dim)
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+ self.rotary = rotary
+ dim = np.int64(dim / 2)
+ self.rotary_embedding = RotaryEmbedding(dim=dim)
+
+ def forward(self, x):
+ h = self.heads
+ x = self.norm(x)
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+ if self.rotary:
+ q = self.rotary_embedding.rotate_queries_or_keys(q)
+ k = self.rotary_embedding.rotate_queries_or_keys(k)
+ q = q * self.scale
+
+ sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)
+
+ attn = sim.softmax(dim=-1)
+ dropped_attn = self.dropout(attn)
+
+ out = torch.einsum("b h i j, b h j d -> b h i d", dropped_attn, v)
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
+ out = self.to_out(out)
+
+ return out, attn
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary=False
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ Attention(
+ dim,
+ heads=heads,
+ dim_head=dim_head,
+ dropout=attn_dropout,
+ rotary=rotary,
+ ),
+ FeedForward(dim, dropout=ff_dropout),
+ ]
+ )
+ )
+
+ def forward(self, x, return_attn=False):
+ post_softmax_attns = []
+
+ for attn, ff in self.layers:
+ attn_out, post_softmax_attn = attn(x)
+ post_softmax_attns.append(post_softmax_attn)
+
+ x = attn_out + x
+ x = ff(x) + x
+
+ if not return_attn:
+ return x
+
+ return x, torch.stack(post_softmax_attns)
diff --git a/mambular/arch_utils/embedding_layer.py b/mambular/arch_utils/embedding_layer.py
new file mode 100644
index 0000000..43fe453
--- /dev/null
+++ b/mambular/arch_utils/embedding_layer.py
@@ -0,0 +1,163 @@
+import torch
+import torch.nn as nn
+
+
+class EmbeddingLayer(nn.Module):
+ def __init__(
+ self,
+ num_feature_info,
+ cat_feature_info,
+ d_model,
+ embedding_activation=nn.Identity(),
+ layer_norm_after_embedding=False,
+ use_cls=False,
+ cls_position=0,
+ cat_encoding="int",
+ ):
+ """
+ Embedding layer that handles numerical and categorical embeddings.
+
+ Parameters
+ ----------
+ num_feature_info : dict
+ Dictionary where keys are numerical feature names and values are their respective input dimensions.
+ cat_feature_info : dict
+ Dictionary where keys are categorical feature names and values are the number of categories for each feature.
+ d_model : int
+ Dimensionality of the embeddings.
+ embedding_activation : nn.Module, optional
+ Activation function to apply after embedding. Default is `nn.Identity()`.
+ layer_norm_after_embedding : bool, optional
+ If True, applies layer normalization after embeddings. Default is `False`.
+ use_cls : bool, optional
+ If True, includes a class token in the embeddings. Default is `False`.
+ cls_position : int, optional
+ Position to place the class token, either at the start (0) or end (1) of the sequence. Default is `0`.
+
+ Methods
+ -------
+ forward(num_features=None, cat_features=None)
+ Defines the forward pass of the model.
+ """
+ super(EmbeddingLayer, self).__init__()
+
+ self.d_model = d_model
+ self.embedding_activation = embedding_activation
+ self.layer_norm_after_embedding = layer_norm_after_embedding
+ self.use_cls = use_cls
+ self.cls_position = cls_position
+
+ self.num_embeddings = nn.ModuleList(
+ [
+ nn.Sequential(
+ nn.Linear(input_shape, d_model, bias=False),
+ self.embedding_activation,
+ )
+ for feature_name, input_shape in num_feature_info.items()
+ ]
+ )
+
+ self.cat_embeddings = nn.ModuleList()
+ for feature_name, num_categories in cat_feature_info.items():
+ if cat_encoding == "int":
+ self.cat_embeddings.append(
+ nn.Sequential(
+ nn.Embedding(num_categories + 1, d_model),
+ self.embedding_activation,
+ )
+ )
+ elif cat_encoding == "one-hot":
+ self.cat_embeddings.append(
+ nn.Sequential(
+ OneHotEncoding(num_categories),
+ nn.Linear(num_categories, d_model, bias=False),
+ self.embedding_activation,
+ )
+ )
+
+ if self.use_cls:
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
+ if layer_norm_after_embedding:
+ self.embedding_norm = nn.LayerNorm(d_model)
+
+ self.seq_len = len(self.num_embeddings) + len(self.cat_embeddings)
+
+ def forward(self, num_features=None, cat_features=None):
+ """
+ Defines the forward pass of the model.
+
+ Parameters
+ ----------
+ num_features : Tensor, optional
+ Tensor containing the numerical features.
+ cat_features : Tensor, optional
+ Tensor containing the categorical features.
+
+ Returns
+ -------
+ Tensor
+ The output embeddings of the model.
+
+ Raises
+ ------
+ ValueError
+ If no features are provided to the model.
+ """
+ if self.use_cls:
+ batch_size = (
+ cat_features[0].size(0)
+ if cat_features != []
+ else num_features[0].size(0)
+ )
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+
+ if self.cat_embeddings and cat_features is not None:
+ cat_embeddings = [
+ emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
+ ]
+ cat_embeddings = torch.stack(cat_embeddings, dim=1)
+ cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
+ if self.layer_norm_after_embedding:
+ cat_embeddings = self.embedding_norm(cat_embeddings)
+ else:
+ cat_embeddings = None
+
+ if self.num_embeddings and num_features is not None:
+ num_embeddings = [
+ emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)
+ ]
+ num_embeddings = torch.stack(num_embeddings, dim=1)
+ if self.layer_norm_after_embedding:
+ num_embeddings = self.embedding_norm(num_embeddings)
+ else:
+ num_embeddings = None
+
+ if cat_embeddings is not None and num_embeddings is not None:
+ x = torch.cat([cat_embeddings, num_embeddings], dim=1)
+ elif cat_embeddings is not None:
+ x = cat_embeddings
+ elif num_embeddings is not None:
+ x = num_embeddings
+ else:
+ raise ValueError("No features provided to the model.")
+
+ if self.use_cls:
+ if self.cls_position == 0:
+ x = torch.cat([cls_tokens, x], dim=1)
+ elif self.cls_position == 1:
+ x = torch.cat([x, cls_tokens], dim=1)
+ else:
+ raise ValueError(
+ "Invalid cls_position value. It should be either 0 or 1."
+ )
+
+ return x
+
+
+class OneHotEncoding(nn.Module):
+ def __init__(self, num_categories):
+ super(OneHotEncoding, self).__init__()
+ self.num_categories = num_categories
+
+ def forward(self, x):
+ return torch.nn.functional.one_hot(x, num_classes=self.num_categories).float()
diff --git a/mambular/arch_utils/learnable_ple.py b/mambular/arch_utils/learnable_ple.py
new file mode 100644
index 0000000..9e53a19
--- /dev/null
+++ b/mambular/arch_utils/learnable_ple.py
@@ -0,0 +1,38 @@
+import torch
+import torch.nn as nn
+
+
+class PeriodicLinearEncodingLayer(nn.Module):
+ def __init__(self, bins=10, learn_bins=True):
+ super(PeriodicLinearEncodingLayer, self).__init__()
+ self.bins = bins
+ self.learn_bins = learn_bins
+
+ if self.learn_bins:
+ # Learnable bin boundaries
+ self.bin_boundaries = nn.Parameter(torch.linspace(0, 1, self.bins + 1))
+ else:
+ self.bin_boundaries = torch.linspace(-1, 1, self.bins + 1)
+
+ def forward(self, x):
+ if self.learn_bins:
+ # Ensure bin boundaries are sorted
+ sorted_bins = torch.sort(self.bin_boundaries)[0]
+ else:
+ sorted_bins = self.bin_boundaries
+
+ # Initialize z with zeros
+ z = torch.zeros(x.size(0), self.bins, device=x.device)
+
+ for t in range(1, self.bins + 1):
+ b_t_1 = sorted_bins[t - 1]
+ b_t = sorted_bins[t]
+ mask1 = x < b_t_1
+ mask2 = x >= b_t
+ mask3 = (x >= b_t_1) & (x < b_t)
+
+ z[mask1.squeeze(), t - 1] = 0
+ z[mask2.squeeze(), t - 1] = 1
+ z[mask3.squeeze(), t - 1] = (x[mask3] - b_t_1) / (b_t - b_t_1)
+
+ return z
diff --git a/mambular/arch_utils/mamba_arch.py b/mambular/arch_utils/mamba_arch.py
index 3db39ed..537b8e5 100644
--- a/mambular/arch_utils/mamba_arch.py
+++ b/mambular/arch_utils/mamba_arch.py
@@ -43,6 +43,9 @@ def __init__(
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
+ layer_norm_eps=1e-05,
+ AD_weight_decay=False,
+ BC_layer_norm=True,
):
super().__init__()
@@ -66,6 +69,9 @@ def __init__(
activation,
bidirectional,
use_learnable_interaction,
+ layer_norm_eps,
+ AD_weight_decay,
+ BC_layer_norm,
)
for _ in range(n_layers)
]
@@ -105,6 +111,9 @@ def __init__(
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
+ layer_norm_eps=1e-05,
+ AD_weight_decay=False,
+ BC_layer_norm=False,
):
super().__init__()
@@ -149,8 +158,11 @@ def __init__(
activation=activation,
bidirectional=bidirectional,
use_learnable_interaction=use_learnable_interaction,
+ layer_norm_eps=layer_norm_eps,
+ AD_weight_decay=AD_weight_decay,
+ BC_layer_norm=BC_layer_norm,
)
- self.norm = norm(d_model)
+ self.norm = norm(d_model, eps=layer_norm_eps)
def forward(self, x):
output = self.layers(self.norm(x)) + x
@@ -189,6 +201,9 @@ def __init__(
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
+ layer_norm_eps=1e-05,
+ AD_weight_decay=False,
+ BC_layer_norm=False,
):
super().__init__()
self.d_inner = d_model * expand_factor
@@ -239,6 +254,7 @@ def __init__(
elif dt_init == "random":
nn.init.uniform_(self.dt_proj_fwd.weight, -dt_init_std, dt_init_std)
if self.bidirectional:
+
nn.init.uniform_(self.dt_proj_bwd.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
@@ -262,17 +278,35 @@ def __init__(
A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
self.A_log_fwd = nn.Parameter(torch.log(A))
+ self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
+
if self.bidirectional:
self.A_log_bwd = nn.Parameter(torch.log(A))
+ self.D_bwd = nn.Parameter(torch.ones(self.d_inner))
+
+ if not AD_weight_decay:
+ self.A_log_fwd._no_weight_decay = True
+ self.D_fwd._no_weight_decay = True
- self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
if self.bidirectional:
- self.D_bwd = nn.Parameter(torch.ones(self.d_inner))
+
+ if not AD_weight_decay:
+ self.A_log_bwd._no_weight_decay = True
+ self.D_bwd._no_weight_decay = True
self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
self.dt_rank = dt_rank
self.d_state = d_state
+ if BC_layer_norm:
+ self.dt_layernorm = RMSNorm(self.dt_rank, eps=layer_norm_eps)
+ self.B_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
+ self.C_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
+ else:
+ self.dt_layernorm = None
+ self.B_layernorm = None
+ self.C_layernorm = None
+
def forward(self, x):
_, L, _ = x.shape
@@ -316,6 +350,15 @@ def forward(self, x):
return output
+ def _apply_layernorms(self, dt, B, C):
+ if self.dt_layernorm is not None:
+ dt = self.dt_layernorm(dt)
+ if self.B_layernorm is not None:
+ B = self.B_layernorm(B)
+ if self.C_layernorm is not None:
+ C = self.C_layernorm(C)
+ return dt, B, C
+
def ssm(self, x, forward=True):
if forward:
A = -torch.exp(self.A_log_fwd.float())
@@ -324,6 +367,7 @@ def ssm(self, x, forward=True):
delta, B, C = torch.split(
deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
)
+ delta, B, C = self._apply_layernorms(delta, B, C)
delta = F.softplus(self.dt_proj_fwd(delta))
else:
A = -torch.exp(self.A_log_bwd.float())
@@ -332,6 +376,7 @@ def ssm(self, x, forward=True):
delta, B, C = torch.split(
deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
)
+ delta, B, C = self._apply_layernorms(delta, B, C)
delta = F.softplus(self.dt_proj_bwd(delta))
y = self.selective_scan_seq(x, delta, A, B, C, D)
diff --git a/mambular/arch_utils/poly_layer.py b/mambular/arch_utils/poly_layer.py
new file mode 100644
index 0000000..6d71157
--- /dev/null
+++ b/mambular/arch_utils/poly_layer.py
@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+from sklearn.preprocessing import PolynomialFeatures, MinMaxScaler, SplineTransformer
+import numpy as np
+
+
+class ScaledPolynomialLayer(nn.Module):
+ def __init__(self, degree=2):
+ super(ScaledPolynomialLayer, self).__init__()
+ self.degree = degree
+
+ # Initialize polynomial feature generator
+ self.poly = PolynomialFeatures(degree=self.degree, include_bias=False)
+ # Initialize learnable scaling parameter
+ self.weights = nn.Parameter(torch.ones(self.degree))
+
+ def forward(self, x):
+ # Scale the input to the range [-1, 1]
+ x_np = x.detach().cpu().numpy()
+ scaler = MinMaxScaler(feature_range=(-1, 1))
+ x_scaled = scaler.fit_transform(x_np) * 1e-05
+
+ # Generate polynomial features
+ poly_features = self.poly.fit_transform(x_scaled)
+
+ # Convert polynomial features back to tensor
+ poly_features = torch.tensor(poly_features, dtype=torch.float32).to(x.device)
+
+ # Apply the learnable scaling parameter
+ output = poly_features * self.weights
+
+ output = torch.clamp(output, min=-1e5, max=1e3)
+
+ return output
+
+
+class ScaledSplineLayer(nn.Module):
+ def __init__(self, degree=3, knots=63, learn_knots=True, learn_weights=True):
+ super(ScaledSplineLayer, self).__init__()
+ self.degree = degree
+ self.knots = knots
+ self.learn_knots = learn_knots
+ self.learn_weights = learn_weights
+
+ # Initialize polynomial feature generator
+ self.spline = SplineTransformer(
+ degree=self.degree, include_bias=False, knots=self.knots
+ )
+
+ if self.learn_knots:
+ # Learnable knots
+ self.knots_positions = nn.Parameter(torch.linspace(-1, 1, self.knots))
+ else:
+ self.knots_positions = torch.linspace(-1, 1, self.knots)
+
+ if self.learn_weights:
+ # Learnable weights for each dimension
+ self.weights = nn.Parameter(torch.ones(self.knots + 1))
+ else:
+ self.weights = torch.ones(self.knots + 1)
+
+ def forward(self, x):
+ # Scale the input to the range [-1, 1]
+ x_np = x.detach().cpu().numpy()
+ scaler = MinMaxScaler(feature_range=(-1, 1))
+ x_scaled = scaler.fit_transform(x_np).reshape(-1, 1)
+
+ if self.learn_knots:
+ # Use learnable knots positions and ensure they are sorted
+ knots = self.knots_positions.detach().cpu().numpy().reshape(-1, 1)
+ sorted_knots = np.sort(knots)
+ self.spline.knots = np.interp(
+ sorted_knots, (sorted_knots.min(), sorted_knots.max()), (0, 1)
+ )
+
+ # Generate spline features
+ spline_features = self.spline.fit_transform(x_scaled)
+
+ # Convert spline features back to tensor
+ spline_features = torch.tensor(spline_features, dtype=torch.float32).to(
+ x.device
+ )
+
+ # Apply the learnable scaling parameter and weights
+ output = spline_features * self.weights
+
+ return output
diff --git a/mambular/arch_utils/rotary_utils.py b/mambular/arch_utils/rotary_utils.py
new file mode 100644
index 0000000..a9bef25
--- /dev/null
+++ b/mambular/arch_utils/rotary_utils.py
@@ -0,0 +1,112 @@
+import torch
+import torch.nn as nn
+from rotary_embedding_torch import RotaryEmbedding
+import numpy as np
+from einops import rearrange
+
+
+class RotaryEmbeddingLayer(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.rotary_embedding = RotaryEmbedding(dim=dim)
+
+ def forward(self, q, k):
+ q = self.rotary_embedding.rotate_queries_or_keys(q)
+ k = self.rotary_embedding.rotate_queries_or_keys(k)
+ return q, k
+
+
+class RotaryTransformerEncoderLayer(nn.TransformerEncoderLayer):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation=nn.SELU(),
+ layer_norm_eps=1e-5,
+ norm_first=False,
+ bias=True,
+ batch_first=False,
+ **kwargs,
+ ):
+ super().__init__(
+ d_model,
+ nhead,
+ dim_feedforward=dim_feedforward,
+ dropout=dropout,
+ activation=activation,
+ layer_norm_eps=layer_norm_eps,
+ norm_first=norm_first,
+ batch_first=batch_first,
+ bias=bias,
+ **kwargs,
+ )
+ self.rotary_embedding = RotaryEmbeddingLayer(dim=d_model // nhead)
+ self.nhead = nhead
+ self.d_model = d_model
+
+ def _sa_block(self, x, attn_mask, key_padding_mask):
+ # Multi-head attention with rotary embedding
+ device = x.device
+ batch_size, seq_length, d_model = x.size()
+ head_dim = d_model // self.nhead
+ qkv = nn.Linear(d_model, d_model * 3, bias=False).to(device)(x)
+ q, k, v = qkv.chunk(3, dim=-1)
+ q, k, v = map(
+ lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.nhead), (q, k, v)
+ )
+
+ # Apply rotary embeddings to queries and keys
+ q, k = self.rotary_embedding(q, k)
+
+ q = q * (head_dim**-0.5)
+ sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)
+ if attn_mask is not None:
+ sim = sim.masked_fill(attn_mask == 0, float("-inf"))
+ attn = sim.softmax(dim=-1)
+ if self.training:
+ attn = self.dropout(attn)
+
+ out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return nn.Linear(d_model, d_model, bias=False).to(device)(out)
+
+ def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
+ # Pre-norm if required
+ device = src.device
+ if self.norm_first:
+ src = self.norm1(src)
+ src2 = self._sa_block(src, src_mask, src_key_padding_mask).to(device)
+ src = src + self.dropout1(src2)
+ src = self.norm2(src)
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+ src = src + self.dropout2(src2)
+ else:
+ src2 = self._sa_block(self.norm1(src), src_mask, src_key_padding_mask).to(
+ device
+ )
+ src = src + self.dropout1(src2)
+ src2 = self.linear2(
+ self.dropout(self.activation(self.linear1(self.norm2(src))))
+ )
+ src = src + self.dropout2(src2)
+
+ return src
+
+
+class RotaryTransformerEncoder(nn.TransformerEncoder):
+ def __init__(
+ self,
+ encoder_layer,
+ num_layers,
+ norm=None,
+ ):
+ super().__init__(
+ encoder_layer,
+ num_layers,
+ norm=norm,
+ )
+
+ def forward(self, src, mask=None, src_key_padding_mask=None):
+ return super().forward(src, mask, src_key_padding_mask)
diff --git a/mambular/arch_utils/transformer_utils.py b/mambular/arch_utils/transformer_utils.py
new file mode 100644
index 0000000..c4aaf6b
--- /dev/null
+++ b/mambular/arch_utils/transformer_utils.py
@@ -0,0 +1,63 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def reglu(x):
+ a, b = x.chunk(2, dim=-1)
+ return a * F.relu(b)
+
+
+class ReGLU(nn.Module):
+ def forward(self, x):
+ return reglu(x)
+
+
+class GLU(nn.Module):
+ def __init__(self):
+ super(GLU, self).__init__()
+
+ def forward(self, x):
+ assert x.size(-1) % 2 == 0, "Input dimension must be even"
+ split_dim = x.size(-1) // 2
+ return x[..., :split_dim] * torch.sigmoid(x[..., split_dim:])
+
+
+class CustomTransformerEncoderLayer(nn.TransformerEncoderLayer):
+ def __init__(self, *args, activation=F.relu, **kwargs):
+ super(CustomTransformerEncoderLayer, self).__init__(
+ *args, activation=activation, **kwargs
+ )
+ self.custom_activation = activation
+
+ # Check if the activation function is an instance of a GLU variant
+ if activation in [ReGLU, GLU] or isinstance(activation, (ReGLU, GLU)):
+ self.linear1 = nn.Linear(
+ self.linear1.in_features,
+ self.linear1.out_features * 2,
+ bias=kwargs.get("bias", True),
+ )
+ self.linear2 = nn.Linear(
+ self.linear2.in_features,
+ self.linear2.out_features,
+ bias=kwargs.get("bias", True),
+ )
+
+ def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
+ src2 = self.self_attn(
+ src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+ )[0]
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+
+ # Use the provided activation function
+ if self.custom_activation in [ReGLU, GLU] or isinstance(
+ self.custom_activation, (ReGLU, GLU)
+ ):
+ src2 = self.linear2(self.custom_activation(self.linear1(src)))
+ else:
+ src2 = self.linear2(self.custom_activation(self.linear1(src)))
+
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
diff --git a/mambular/base_models/__init__.py b/mambular/base_models/__init__.py
index 7027074..6756093 100644
--- a/mambular/base_models/__init__.py
+++ b/mambular/base_models/__init__.py
@@ -5,6 +5,7 @@
from .mlp import MLP
from .resnet import ResNet
from .tabtransformer import TabTransformer
+from .mambatab import MambaTab
__all__ = [
"TaskModel",
@@ -14,4 +15,5 @@
"TabTransformer",
"MLP",
"BaseModel",
+ "MambaTab",
]
diff --git a/mambular/base_models/ft_transformer.py b/mambular/base_models/ft_transformer.py
index 45b3273..ddbf03c 100644
--- a/mambular/base_models/ft_transformer.py
+++ b/mambular/base_models/ft_transformer.py
@@ -9,6 +9,8 @@
InstanceNorm,
GroupNorm,
)
+from ..arch_utils.embedding_layer import EmbeddingLayer
+from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer
from ..configs.fttransformer_config import DefaultFTTransformerConfig
from .basemodel import BaseModel
@@ -83,11 +85,7 @@ def __init__(
self.cat_feature_info = cat_feature_info
self.num_feature_info = num_feature_info
- self.embedding_activation = self.hparams.get(
- "num_embedding_activation", config.num_embedding_activation
- )
-
- encoder_layer = nn.TransformerEncoderLayer(
+ encoder_layer = CustomTransformerEncoderLayer(
d_model=self.hparams.get("d_model", config.d_model),
nhead=self.hparams.get("n_heads", config.n_heads),
batch_first=True,
@@ -127,27 +125,19 @@ def __init__(
norm=self.norm_f,
)
- self.num_embeddings = nn.ModuleList(
- [
- nn.Sequential(
- nn.Linear(
- input_shape,
- self.hparams.get("d_model", config.d_model),
- bias=False,
- ),
- self.embedding_activation,
- )
- for feature_name, input_shape in num_feature_info.items()
- ]
- )
-
- self.cat_embeddings = nn.ModuleList(
- [
- nn.Embedding(
- num_categories + 1, self.hparams.get("d_model", config.d_model)
- )
- for feature_name, num_categories in cat_feature_info.items()
- ]
+ self.embedding_layer = EmbeddingLayer(
+ num_feature_info=num_feature_info,
+ cat_feature_info=cat_feature_info,
+ d_model=self.hparams.get("d_model", config.d_model),
+ embedding_activation=self.hparams.get(
+ "embedding_activation", config.embedding_activation
+ ),
+ layer_norm_after_embedding=self.hparams.get(
+ "layer_norm_after_embedding", config.layer_norm_after_embedding
+ ),
+ use_cls=True,
+ cls_position=0,
+ cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
)
head_activation = self.hparams.get("head_activation", config.head_activation)
@@ -168,15 +158,6 @@ def __init__(
n_output_units=num_classes,
)
- self.cls_token = nn.Parameter(
- torch.zeros(1, 1, self.hparams.get("d_model", config.d_model))
- )
-
- if self.hparams.get("layer_norm_after_embedding"):
- self.embedding_norm = nn.LayerNorm(
- self.hparams.get("d_model", config.d_model)
- )
-
def forward(self, num_features, cat_features):
"""
Defines the forward pass of the model.
@@ -193,40 +174,7 @@ def forward(self, num_features, cat_features):
Tensor
The output predictions of the model.
"""
- batch_size = (
- cat_features[0].size(0) if cat_features != [] else num_features[0].size(0)
- )
- cls_tokens = self.cls_token.expand(batch_size, -1, -1)
-
- if len(self.cat_embeddings) > 0 and cat_features:
- cat_embeddings = [
- emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
- ]
- cat_embeddings = torch.stack(cat_embeddings, dim=1)
- cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
- if self.hparams.get("layer_norm_after_embedding"):
- cat_embeddings = self.embedding_norm(cat_embeddings)
- else:
- cat_embeddings = None
-
- if len(self.num_embeddings) > 0 and num_features:
- num_embeddings = [
- emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)
- ]
- num_embeddings = torch.stack(num_embeddings, dim=1)
- if self.hparams.get("layer_norm_after_embedding"):
- num_embeddings = self.embedding_norm(num_embeddings)
- else:
- num_embeddings = None
-
- if cat_embeddings is not None and num_embeddings is not None:
- x = torch.cat([cls_tokens, cat_embeddings, num_embeddings], dim=1)
- elif cat_embeddings is not None:
- x = torch.cat([cls_tokens, cat_embeddings], dim=1)
- elif num_embeddings is not None:
- x = torch.cat([cls_tokens, num_embeddings], dim=1)
- else:
- raise ValueError("No features provided to the model.")
+ x = self.embedding_layer(num_features, cat_features)
x = self.encoder(x)
diff --git a/mambular/base_models/lightning_wrapper.py b/mambular/base_models/lightning_wrapper.py
index 39002d9..6d3f5c3 100644
--- a/mambular/base_models/lightning_wrapper.py
+++ b/mambular/base_models/lightning_wrapper.py
@@ -37,7 +37,7 @@ def __init__(
lss=False,
family=None,
loss_fct: callable = None,
- **kwargs
+ **kwargs,
):
super().__init__()
self.num_classes = num_classes
@@ -82,7 +82,7 @@ def __init__(
else:
output_dim = num_classes
- self.model = model_class(
+ self.base_model = model_class(
config=config,
num_feature_info=num_feature_info,
cat_feature_info=cat_feature_info,
@@ -107,7 +107,7 @@ def forward(self, num_features, cat_features):
Model output.
"""
- return self.model.forward(num_features, cat_features)
+ return self.base_model.forward(num_features, cat_features)
def compute_loss(self, predictions, y_true):
"""
@@ -126,7 +126,7 @@ def compute_loss(self, predictions, y_true):
Computed loss.
"""
if self.lss:
- return self.family.compute_loss(predictions, y_true)
+ return self.family.compute_loss(predictions, y_true.squeeze(-1))
else:
loss = self.loss_fct(predictions, y_true)
return loss
@@ -168,16 +168,6 @@ def training_step(self, batch, batch_idx):
prog_bar=True,
logger=True,
)
- elif isinstance(self.loss_fct, nn.MSELoss):
- rmse = torch.sqrt(loss)
- self.log(
- "train_rmse",
- rmse,
- on_step=True,
- on_epoch=True,
- prog_bar=True,
- logger=True,
- )
return loss
@@ -205,7 +195,7 @@ def validation_step(self, batch, batch_idx):
self.log(
"val_loss",
val_loss,
- on_step=True,
+ on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
@@ -218,17 +208,7 @@ def validation_step(self, batch, batch_idx):
self.log(
"val_acc",
acc,
- on_step=True,
- on_epoch=True,
- prog_bar=True,
- logger=True,
- )
- elif isinstance(self.loss_fct, nn.MSELoss):
- rmse = torch.sqrt(val_loss)
- self.log(
- "val_rmse",
- rmse,
- on_step=True,
+ on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
@@ -272,17 +252,7 @@ def test_step(self, batch, batch_idx):
self.log(
"test_acc",
acc,
- on_step=True,
- on_epoch=True,
- prog_bar=True,
- logger=True,
- )
- elif isinstance(self.loss_fct, nn.MSELoss):
- rmse = torch.sqrt(test_loss)
- self.log(
- "test_rmse",
- rmse,
- on_step=True,
+ on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
@@ -300,7 +270,7 @@ def configure_optimizers(self):
A dictionary containing the optimizer and lr_scheduler configurations.
"""
optimizer = torch.optim.Adam(
- self.parameters(),
+ self.base_model.parameters(),
lr=self.lr,
weight_decay=self.weight_decay,
)
diff --git a/mambular/base_models/mambatab.py b/mambular/base_models/mambatab.py
new file mode 100644
index 0000000..0e21dbe
--- /dev/null
+++ b/mambular/base_models/mambatab.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+from ..arch_utils.mamba_arch import Mamba
+from ..arch_utils.mlp_utils import MLP
+from ..arch_utils.normalization_layers import (
+ RMSNorm,
+ LayerNorm,
+ LearnableLayerScaling,
+ BatchNorm,
+ InstanceNorm,
+ GroupNorm,
+)
+from ..configs.mambatab_config import DefaultMambaTabConfig
+from .basemodel import BaseModel
+
+
+class MambaTab(BaseModel):
+ def __init__(
+ self,
+ cat_feature_info,
+ num_feature_info,
+ num_classes=1,
+ config: DefaultMambaTabConfig = DefaultMambaTabConfig(),
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+
+ input_dim = 0
+ for feature_name, input_shape in num_feature_info.items():
+ input_dim += input_shape
+ for feature_name, input_shape in cat_feature_info.items():
+ input_dim += 1
+
+ self.lr = self.hparams.get("lr", config.lr)
+ self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
+ self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
+ self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
+ self.cat_feature_info = cat_feature_info
+ self.num_feature_info = num_feature_info
+
+ self.initial_layer = nn.Linear(input_dim, config.d_model)
+ self.norm_f = LayerNorm(config.d_model)
+
+ self.embedding_activation = self.hparams.get(
+ "num_embedding_activation", config.num_embedding_activation
+ )
+
+ self.axis = config.axis
+
+ head_activation = self.hparams.get("head_activation", config.head_activation)
+
+ self.tabular_head = MLP(
+ self.hparams.get("d_model", config.d_model),
+ hidden_units_list=self.hparams.get(
+ "head_layer_sizes", config.head_layer_sizes
+ ),
+ dropout_rate=self.hparams.get("head_dropout", config.head_dropout),
+ use_skip_layers=self.hparams.get(
+ "head_skip_layers", config.head_skip_layers
+ ),
+ activation_fn=head_activation,
+ use_batch_norm=self.hparams.get(
+ "head_use_batch_norm", config.head_use_batch_norm
+ ),
+ n_output_units=num_classes,
+ )
+
+ self.mamba = Mamba(
+ d_model=self.hparams.get("d_model", config.d_model),
+ n_layers=self.hparams.get("n_layers", config.n_layers),
+ expand_factor=self.hparams.get("expand_factor", config.expand_factor),
+ bias=self.hparams.get("bias", config.bias),
+ d_conv=self.hparams.get("d_conv", config.d_conv),
+ conv_bias=self.hparams.get("conv_bias", config.conv_bias),
+ dropout=self.hparams.get("dropout", config.dropout),
+ dt_rank=self.hparams.get("dt_rank", config.dt_rank),
+ d_state=self.hparams.get("d_state", config.d_state),
+ dt_scale=self.hparams.get("dt_scale", config.dt_scale),
+ dt_init=self.hparams.get("dt_init", config.dt_init),
+ dt_max=self.hparams.get("dt_max", config.dt_max),
+ dt_min=self.hparams.get("dt_min", config.dt_min),
+ dt_init_floor=self.hparams.get("dt_init_floor", config.dt_init_floor),
+ activation=self.hparams.get("activation", config.activation),
+ bidirectional=False,
+ use_learnable_interaction=False,
+ )
+
+ def forward(self, num_features, cat_features):
+ x = num_features + cat_features
+ x = torch.cat(x, dim=1)
+
+ x = self.initial_layer(x)
+ if self.axis == 1:
+ x = x.unsqueeze(1)
+
+ else:
+ x = x.unsqueeze(0)
+
+ x = self.norm_f(x)
+ x = self.embedding_activation(x)
+ if self.axis == 1:
+ x = x.squeeze(1)
+ else:
+ x = x.squeeze(0)
+
+ preds = self.tabular_head(x)
+
+ return preds
diff --git a/mambular/base_models/mambular.py b/mambular/base_models/mambular.py
index 31a18b9..d362b8a 100644
--- a/mambular/base_models/mambular.py
+++ b/mambular/base_models/mambular.py
@@ -12,6 +12,7 @@
)
from ..configs.mambular_config import DefaultMambularConfig
from .basemodel import BaseModel
+from ..arch_utils.embedding_layer import EmbeddingLayer
class Mambular(BaseModel):
@@ -81,13 +82,12 @@ def __init__(
self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
self.pooling_method = self.hparams.get("pooling_method", config.pooling_method)
+ self.shuffle_embeddings = self.hparams.get(
+ "shuffle_embeddings", config.shuffle_embeddings
+ )
self.cat_feature_info = cat_feature_info
self.num_feature_info = num_feature_info
- self.embedding_activation = self.hparams.get(
- "num_embedding_activation", config.num_embedding_activation
- )
-
self.mamba = Mamba(
d_model=self.hparams.get("d_model", config.d_model),
n_layers=self.hparams.get("n_layers", config.n_layers),
@@ -109,19 +109,33 @@ def __init__(
use_learnable_interaction=self.hparams.get(
"use_learnable_interactions", config.use_learnable_interaction
),
+ AD_weight_decay=self.hparams.get("AB_weight_decay", config.AD_weight_decay),
+ BC_layer_norm=self.hparams.get("AB_layer_norm", config.BC_layer_norm),
+ layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps),
)
-
norm_layer = self.hparams.get("norm", config.norm)
if norm_layer == "RMSNorm":
- self.norm_f = RMSNorm(self.hparams.get("d_model", config.d_model))
+ self.norm_f = RMSNorm(
+ self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
+ )
elif norm_layer == "LayerNorm":
- self.norm_f = LayerNorm(self.hparams.get("d_model", config.d_model))
+ self.norm_f = LayerNorm(
+ self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
+ )
elif norm_layer == "BatchNorm":
- self.norm_f = BatchNorm(self.hparams.get("d_model", config.d_model))
+ self.norm_f = BatchNorm(
+ self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
+ )
elif norm_layer == "InstanceNorm":
- self.norm_f = InstanceNorm(self.hparams.get("d_model", config.d_model))
+ self.norm_f = InstanceNorm(
+ self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
+ )
elif norm_layer == "GroupNorm":
- self.norm_f = GroupNorm(1, self.hparams.get("d_model", config.d_model))
+ self.norm_f = GroupNorm(
+ 1,
+ self.hparams.get("d_model", config.d_model),
+ eps=config.layer_norm_eps,
+ )
elif norm_layer == "LearnableLayerScaling":
self.norm_f = LearnableLayerScaling(
self.hparams.get("d_model", config.d_model)
@@ -129,27 +143,19 @@ def __init__(
else:
raise ValueError(f"Unsupported normalization layer: {norm_layer}")
- self.num_embeddings = nn.ModuleList(
- [
- nn.Sequential(
- nn.Linear(
- input_shape,
- self.hparams.get("d_model", config.d_model),
- bias=False,
- ),
- self.embedding_activation,
- )
- for feature_name, input_shape in num_feature_info.items()
- ]
- )
-
- self.cat_embeddings = nn.ModuleList(
- [
- nn.Embedding(
- num_categories + 1, self.hparams.get("d_model", config.d_model)
- )
- for feature_name, num_categories in cat_feature_info.items()
- ]
+ self.embedding_layer = EmbeddingLayer(
+ num_feature_info=num_feature_info,
+ cat_feature_info=cat_feature_info,
+ d_model=self.hparams.get("d_model", config.d_model),
+ embedding_activation=self.hparams.get(
+ "embedding_activation", config.embedding_activation
+ ),
+ layer_norm_after_embedding=self.hparams.get(
+ "layer_norm_after_embedding", config.layer_norm_after_embedding
+ ),
+ use_cls=False,
+ cls_position=-1,
+ cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
)
head_activation = self.hparams.get("head_activation", config.head_activation)
@@ -170,17 +176,13 @@ def __init__(
n_output_units=num_classes,
)
- self.cls_token = nn.Parameter(
- torch.zeros(1, 1, self.hparams.get("d_model", config.d_model))
- )
-
- if self.hparams.get("layer_norm_after_embedding"):
- self.embedding_norm = nn.LayerNorm(
- self.hparams.get("d_model", config.d_model)
- )
+ if self.pooling_method == "cls":
+ self.use_cls = True
+ else:
+ self.use_cls = self.hparams.get("use_cls", config.use_cls)
- def __post__init(self):
- pass
+ if self.shuffle_embeddings:
+ self.perm = torch.randperm(self.embedding_layer.seq_len)
def forward(self, num_features, cat_features):
"""
@@ -198,40 +200,10 @@ def forward(self, num_features, cat_features):
Tensor
The output predictions of the model.
"""
- batch_size = (
- cat_features[0].size(0) if cat_features != [] else num_features[0].size(0)
- )
- cls_tokens = self.cls_token.expand(batch_size, -1, -1)
-
- if len(self.cat_embeddings) > 0 and cat_features:
- cat_embeddings = [
- emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
- ]
- cat_embeddings = torch.stack(cat_embeddings, dim=1)
- cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
- if self.hparams.get("layer_norm_after_embedding"):
- cat_embeddings = self.embedding_norm(cat_embeddings)
- else:
- cat_embeddings = None
-
- if len(self.num_embeddings) > 0 and num_features:
- num_embeddings = [
- emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)
- ]
- num_embeddings = torch.stack(num_embeddings, dim=1)
- if self.hparams.get("layer_norm_after_embedding"):
- num_embeddings = self.embedding_norm(num_embeddings)
- else:
- num_embeddings = None
-
- if cat_embeddings is not None and num_embeddings is not None:
- x = torch.cat([cls_tokens, cat_embeddings, num_embeddings], dim=1)
- elif cat_embeddings is not None:
- x = torch.cat([cls_tokens, cat_embeddings], dim=1)
- elif num_embeddings is not None:
- x = torch.cat([cls_tokens, num_embeddings], dim=1)
- else:
- raise ValueError("No features provided to the model.")
+ x = self.embedding_layer(num_features, cat_features)
+
+ if self.shuffle_embeddings:
+ x = x[:, self.perm, :]
x = self.mamba(x)
@@ -242,7 +214,9 @@ def forward(self, num_features, cat_features):
elif self.pooling_method == "sum":
x = torch.sum(x, dim=1)
elif self.pooling_method == "cls_token":
- x = x[:, 0]
+ x = x[:, -1]
+ elif self.pooling_method == "last":
+ x = x[:, -1]
else:
raise ValueError(f"Invalid pooling method: {self.pooling_method}")
diff --git a/mambular/base_models/mlp.py b/mambular/base_models/mlp.py
index c84447b..9f61cab 100644
--- a/mambular/base_models/mlp.py
+++ b/mambular/base_models/mlp.py
@@ -10,6 +10,7 @@
InstanceNorm,
GroupNorm,
)
+from ..arch_utils.embedding_layer import EmbeddingLayer
class MLP(BaseModel):
@@ -39,12 +40,6 @@ def __init__(
super().__init__(**kwargs)
self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
- input_dim = 0
- for feature_name, input_shape in num_feature_info.items():
- input_dim += input_shape
- for feature_name, input_shape in cat_feature_info.items():
- input_dim += 1
-
self.lr = self.hparams.get("lr", config.lr)
self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
@@ -59,6 +54,19 @@ def __init__(
)
self.use_glu = self.hparams.get("use_glu", config.use_glu)
self.activation = self.hparams.get("activation", config.activation)
+ self.use_embeddings = self.hparams.get("use_embeddings", config.use_embeddings)
+
+ input_dim = 0
+ for feature_name, input_shape in num_feature_info.items():
+ input_dim += input_shape
+ for feature_name, input_shape in cat_feature_info.items():
+ input_dim += 1
+
+ if self.use_embeddings:
+ input_dim = (
+ len(num_feature_info) * config.d_model
+ + len(cat_feature_info) * config.d_model
+ )
# Input layer
self.layers.append(nn.Linear(input_dim, config.layer_sizes[0]))
@@ -110,6 +118,20 @@ def __init__(
# Output layer
self.layers.append(nn.Linear(config.layer_sizes[-1], num_classes))
+ if self.use_embeddings:
+ self.embedding_layer = EmbeddingLayer(
+ num_feature_info=num_feature_info,
+ cat_feature_info=cat_feature_info,
+ d_model=self.hparams.get("d_model", config.d_model),
+ embedding_activation=self.hparams.get(
+ "embedding_activation", config.embedding_activation
+ ),
+ layer_norm_after_embedding=self.hparams.get(
+ "layer_norm_after_embedding"
+ ),
+ use_cls=False,
+ )
+
def forward(self, num_features, cat_features) -> torch.Tensor:
"""
Forward pass of the MLP model.
@@ -124,8 +146,13 @@ def forward(self, num_features, cat_features) -> torch.Tensor:
torch.Tensor
Output tensor.
"""
- x = num_features + cat_features
- x = torch.cat(x, dim=1)
+ if self.use_embeddings:
+ x = self.embedding_layer(num_features, cat_features)
+ B, S, D = x.shape
+ x = x.reshape(B, S * D)
+ else:
+ x = num_features + cat_features
+ x = torch.cat(x, dim=1)
for i in range(len(self.layers) - 1):
if isinstance(self.layers[i], nn.Linear):
diff --git a/mambular/base_models/resnet.py b/mambular/base_models/resnet.py
index e8c4981..a6a03b7 100644
--- a/mambular/base_models/resnet.py
+++ b/mambular/base_models/resnet.py
@@ -12,6 +12,7 @@
GroupNorm,
)
from ..arch_utils.resnet_utils import ResidualBlock
+from ..arch_utils.embedding_layer import EmbeddingLayer
class ResNet(BaseModel):
@@ -40,20 +41,27 @@ def __init__(
super().__init__(**kwargs)
self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
- input_dim = 0
- for feature_name, input_shape in num_feature_info.items():
- input_dim += input_shape
- for feature_name, input_shape in cat_feature_info.items():
- input_dim += 1
-
self.lr = self.hparams.get("lr", config.lr)
self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
self.cat_feature_info = cat_feature_info
self.num_feature_info = num_feature_info
-
self.activation = config.activation
+ self.use_embeddings = self.hparams.get("use_embeddings", config.use_embeddings)
+
+ input_dim = 0
+ for feature_name, input_shape in num_feature_info.items():
+ input_dim += input_shape
+ for feature_name, input_shape in cat_feature_info.items():
+ input_dim += 1
+
+ if self.use_embeddings:
+ input_dim = (
+ len(num_feature_info) * config.d_model
+ + len(cat_feature_info) * config.d_model
+ )
+
norm_layer = self.hparams.get("norm", config.norm)
if norm_layer == "RMSNorm":
self.norm_f = RMSNorm
@@ -91,6 +99,20 @@ def __init__(
self.output_layer = nn.Linear(config.layer_sizes[-1], num_classes)
+ if self.use_embeddings:
+ self.embedding_layer = EmbeddingLayer(
+ num_feature_info=num_feature_info,
+ cat_feature_info=cat_feature_info,
+ d_model=self.hparams.get("d_model", config.d_model),
+ embedding_activation=self.hparams.get(
+ "embedding_activation", config.embedding_activation
+ ),
+ layer_norm_after_embedding=self.hparams.get(
+ "layer_norm_after_embedding"
+ ),
+ use_cls=False,
+ )
+
def forward(self, num_features, cat_features):
"""
Forward pass of the ResNet model.
@@ -107,8 +129,13 @@ def forward(self, num_features, cat_features):
torch.Tensor
Output tensor.
"""
- x = num_features + cat_features
- x = torch.cat(x, dim=1)
+ if self.use_embeddings:
+ x = self.embedding_layer(num_features, cat_features)
+ B, S, D = x.shape
+ x = x.reshape(B, S * D)
+ else:
+ x = num_features + cat_features
+ x = torch.cat(x, dim=1)
x = self.initial_layer(x)
for block in self.blocks:
diff --git a/mambular/base_models/tabtransformer.py b/mambular/base_models/tabtransformer.py
index 7fcd11f..d9c5052 100644
--- a/mambular/base_models/tabtransformer.py
+++ b/mambular/base_models/tabtransformer.py
@@ -9,8 +9,10 @@
InstanceNorm,
GroupNorm,
)
+from ..arch_utils.embedding_layer import EmbeddingLayer
from ..configs.tabtransformer_config import DefaultTabTransformerConfig
from .basemodel import BaseModel
+from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer
class TabTransformer(BaseModel):
@@ -74,6 +76,10 @@ def __init__(
):
super().__init__(**kwargs)
self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+ if cat_feature_info == {}:
+ raise ValueError(
+ "You are trying to fit a TabTransformer with no categorical features. Try using a different model that is better suited for tasks without categorical features."
+ )
layer_norm_dim = 0
for feature_name, input_shape in num_feature_info.items():
@@ -87,11 +93,7 @@ def __init__(
self.cat_feature_info = cat_feature_info
self.num_feature_info = num_feature_info
- self.embedding_activation = self.hparams.get(
- "num_embedding_activation", config.num_embedding_activation
- )
-
- encoder_layer = nn.TransformerEncoderLayer(
+ encoder_layer = CustomTransformerEncoderLayer(
d_model=self.hparams.get("d_model", config.d_model),
nhead=self.hparams.get("n_heads", config.n_heads),
batch_first=True,
@@ -130,13 +132,19 @@ def __init__(
norm=self.norm_embedding,
)
- self.cat_embeddings = nn.ModuleList(
- [
- nn.Embedding(
- num_categories + 1, self.hparams.get("d_model", config.d_model)
- )
- for feature_name, num_categories in cat_feature_info.items()
- ]
+ self.embedding_layer = EmbeddingLayer(
+ num_feature_info=num_feature_info,
+ cat_feature_info=cat_feature_info,
+ d_model=self.hparams.get("d_model", config.d_model),
+ embedding_activation=self.hparams.get(
+ "embedding_activation", config.embedding_activation
+ ),
+ layer_norm_after_embedding=self.hparams.get(
+ "layer_norm_after_embedding", config.layer_norm_after_embedding
+ ),
+ use_cls=True,
+ cls_position=0,
+ cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
)
head_activation = self.hparams.get("head_activation", config.head_activation)
@@ -165,11 +173,6 @@ def __init__(
torch.zeros(1, 1, self.hparams.get("d_model", config.d_model))
)
- if self.hparams.get("layer_norm_after_embedding"):
- self.embedding_norm = nn.LayerNorm(
- self.hparams.get("d_model", config.d_model)
- )
-
def forward(self, num_features, cat_features):
"""
Defines the forward pass of the model.
@@ -186,19 +189,7 @@ def forward(self, num_features, cat_features):
Tensor
The output predictions of the model.
"""
- if cat_features == []:
- raise ValueError("No categorical features provided.")
-
- if len(self.cat_embeddings) > 0 and cat_features:
- cat_embeddings = [
- emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
- ]
- cat_embeddings = torch.stack(cat_embeddings, dim=1)
- cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
- if self.hparams.get("layer_norm_after_embedding"):
- cat_embeddings = self.embedding_norm(cat_embeddings)
- else:
- cat_embeddings = None
+ cat_embeddings = self.embedding_layer({}, cat_features)
num_features = torch.cat(num_features, dim=1)
num_embeddings = self.norm_f(num_features)
@@ -211,6 +202,8 @@ def forward(self, num_features, cat_features):
x, _ = torch.max(x, dim=1)
elif self.pooling_method == "sum":
x = torch.sum(x, dim=1)
+ elif self.pooling_method == "cls":
+ x = x[:, 0]
else:
raise ValueError(f"Invalid pooling method: {self.pooling_method}")
diff --git a/mambular/base_models/tabularnn.py b/mambular/base_models/tabularnn.py
new file mode 100644
index 0000000..a3e31bc
--- /dev/null
+++ b/mambular/base_models/tabularnn.py
@@ -0,0 +1,153 @@
+import torch
+import torch.nn as nn
+from ..arch_utils.mlp_utils import MLP
+from ..configs.tabularnn_config import DefaultTabulaRNNConfig
+from .basemodel import BaseModel
+from ..arch_utils.embedding_layer import EmbeddingLayer
+from ..arch_utils.normalization_layers import (
+ RMSNorm,
+ LayerNorm,
+ LearnableLayerScaling,
+ BatchNorm,
+ InstanceNorm,
+ GroupNorm,
+)
+
+
+class TabulaRNN(BaseModel):
+ def __init__(
+ self,
+ cat_feature_info,
+ num_feature_info,
+ num_classes=1,
+ config: DefaultTabulaRNNConfig = DefaultTabulaRNNConfig(),
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+
+ self.lr = self.hparams.get("lr", config.lr)
+ self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
+ self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
+ self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
+ self.pooling_method = self.hparams.get("pooling_method", config.pooling_method)
+ self.cat_feature_info = cat_feature_info
+ self.num_feature_info = num_feature_info
+
+ norm_layer = self.hparams.get("norm", config.norm)
+ if norm_layer == "RMSNorm":
+ self.norm_f = RMSNorm(
+ self.hparams.get("dim_feedforward", config.dim_feedforward)
+ )
+ elif norm_layer == "LayerNorm":
+ self.norm_f = LayerNorm(
+ self.hparams.get("dim_feedforward", config.dim_feedforward)
+ )
+ elif norm_layer == "BatchNorm":
+ self.norm_f = BatchNorm(
+ self.hparams.get("dim_feedforward", config.dim_feedforward)
+ )
+ elif norm_layer == "InstanceNorm":
+ self.norm_f = InstanceNorm(
+ self.hparams.get("dim_feedforward", config.dim_feedforward)
+ )
+ elif norm_layer == "GroupNorm":
+ self.norm_f = GroupNorm(
+ 1, self.hparams.get("dim_feedforward", config.dim_feedforward)
+ )
+ elif norm_layer == "LearnableLayerScaling":
+ self.norm_f = LearnableLayerScaling(
+ self.hparams.get("dim_feedforward", config.dim_feedforward)
+ )
+ else:
+ self.norm_f = None
+
+ rnn_layer = {"RNN": nn.RNN, "LSTM": nn.LSTM, "GRU": nn.GRU}[config.model_type]
+ self.rnn = rnn_layer(
+ input_size=self.hparams.get("d_model", config.d_model),
+ hidden_size=self.hparams.get("dim_feedforward", config.dim_feedforward),
+ num_layers=self.hparams.get("n_layers", config.n_layers),
+ bidirectional=self.hparams.get("bidirectional", config.bidirectional),
+ batch_first=True,
+ dropout=self.hparams.get("rnn_dropout", config.rnn_dropout),
+ bias=self.hparams.get("bias", config.bias),
+ nonlinearity=(
+ self.hparams.get("rnn_activation", config.rnn_activation)
+ if config.model_type == "RNN"
+ else None
+ ),
+ )
+
+ self.embedding_layer = EmbeddingLayer(
+ num_feature_info=num_feature_info,
+ cat_feature_info=cat_feature_info,
+ d_model=self.hparams.get("d_model", config.d_model),
+ embedding_activation=self.hparams.get(
+ "embedding_activation", config.embedding_activation
+ ),
+ layer_norm_after_embedding=self.hparams.get(
+ "layer_norm_after_embedding", config.layer_norm_after_embedding
+ ),
+ use_cls=False,
+ cls_position=-1,
+ cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
+ )
+
+ head_activation = self.hparams.get("head_activation", config.head_activation)
+
+ self.tabular_head = MLP(
+ self.hparams.get("dim_feedforward", config.dim_feedforward),
+ hidden_units_list=self.hparams.get(
+ "head_layer_sizes", config.head_layer_sizes
+ ),
+ dropout_rate=self.hparams.get("head_dropout", config.head_dropout),
+ use_skip_layers=self.hparams.get(
+ "head_skip_layers", config.head_skip_layers
+ ),
+ activation_fn=head_activation,
+ use_batch_norm=self.hparams.get(
+ "head_use_batch_norm", config.head_use_batch_norm
+ ),
+ n_output_units=num_classes,
+ )
+
+ self.linear = nn.Linear(config.d_model, config.dim_feedforward)
+
+ def forward(self, num_features, cat_features):
+ """
+ Defines the forward pass of the model.
+
+ Parameters
+ ----------
+ num_features : Tensor
+ Tensor containing the numerical features.
+ cat_features : Tensor
+ Tensor containing the categorical features.
+
+ Returns
+ -------
+ Tensor
+ The output predictions of the model.
+ """
+
+ x = self.embedding_layer(num_features, cat_features)
+ # RNN forward pass
+ out, _ = self.rnn(x)
+ z = self.linear(torch.mean(x, dim=1))
+
+ if self.pooling_method == "avg":
+ x = torch.mean(out, dim=1)
+ elif self.pooling_method == "max":
+ x, _ = torch.max(out, dim=1)
+ elif self.pooling_method == "sum":
+ x = torch.sum(out, dim=1)
+ elif self.pooling_method == "last":
+ x = x[:, -1, :]
+ else:
+ raise ValueError(f"Invalid pooling method: {self.pooling_method}")
+ x = x + z
+ if self.norm_f is not None:
+ x = self.norm_f(x)
+ preds = self.tabular_head(x)
+
+ return preds
diff --git a/mambular/configs/fttransformer_config.py b/mambular/configs/fttransformer_config.py
index 2e219ce..a433753 100644
--- a/mambular/configs/fttransformer_config.py
+++ b/mambular/configs/fttransformer_config.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass
import torch.nn as nn
+from ..arch_utils.transformer_utils import ReGLU
@dataclass
@@ -31,8 +32,8 @@ class DefaultFTTransformerConfig:
Normalization method to be used.
activation : callable, default=nn.SELU()
Activation function for the transformer.
- num_embedding_activation : callable, default=nn.Identity()
- Activation function for numerical embeddings.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
head_layer_sizes : list, default=(128, 64, 32)
Sizes of the layers in the head of the model.
head_dropout : float, default=0.5
@@ -57,21 +58,23 @@ class DefaultFTTransformerConfig:
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
+ cat_encoding : str, default="int"
+ whether to use integer encoding or one-hot encoding for cat features.
"""
lr: float = 1e-04
lr_patience: int = 10
weight_decay: float = 1e-06
lr_factor: float = 0.1
- d_model: int = 64
- n_layers: int = 8
- n_heads: int = 4
- attn_dropout: float = 0.3
- ff_dropout: float = 0.3
- norm: str = "RMSNorm"
+ d_model: int = 128
+ n_layers: int = 4
+ n_heads: int = 8
+ attn_dropout: float = 0.2
+ ff_dropout: float = 0.1
+ norm: str = "LayerNorm"
activation: callable = nn.SELU()
- num_embedding_activation: callable = nn.Identity()
- head_layer_sizes: list = (128, 64, 32)
+ embedding_activation: callable = nn.Identity()
+ head_layer_sizes: list = ()
head_dropout: float = 0.5
head_skip_layers: bool = False
head_activation: callable = nn.SELU()
@@ -80,6 +83,7 @@ class DefaultFTTransformerConfig:
pooling_method: str = "cls"
norm_first: bool = False
bias: bool = True
- transformer_activation: callable = nn.SELU()
+ transformer_activation: callable = ReGLU()
layer_norm_eps: float = 1e-05
- transformer_dim_feedforward: int = 512
+ transformer_dim_feedforward: int = 256
+ cat_encoding: str = "int"
diff --git a/mambular/configs/mambatab_config.py b/mambular/configs/mambatab_config.py
new file mode 100644
index 0000000..3ebea6f
--- /dev/null
+++ b/mambular/configs/mambatab_config.py
@@ -0,0 +1,94 @@
+from dataclasses import dataclass
+import torch.nn as nn
+
+
+@dataclass
+class DefaultMambaTabConfig:
+ """
+ Configuration class for the Default Mambular model with predefined hyperparameters.
+
+ Parameters
+ ----------
+ lr : float, default=1e-04
+ Learning rate for the optimizer.
+ lr_patience : int, default=10
+ Number of epochs with no improvement after which learning rate will be reduced.
+ weight_decay : float, default=1e-06
+ Weight decay (L2 penalty) for the optimizer.
+ lr_factor : float, default=0.1
+ Factor by which the learning rate will be reduced.
+ d_model : int, default=64
+ Dimensionality of the model.
+ n_layers : int, default=8
+ Number of layers in the model.
+ expand_factor : int, default=2
+ Expansion factor for the feed-forward layers.
+ bias : bool, default=False
+ Whether to use bias in the linear layers.
+ d_conv : int, default=16
+ Dimensionality of the convolutional layers.
+ conv_bias : bool, default=True
+ Whether to use bias in the convolutional layers.
+ dropout : float, default=0.05
+ Dropout rate for regularization.
+ dt_rank : str, default="auto"
+ Rank of the decision tree.
+ d_state : int, default=32
+ Dimensionality of the state in recurrent layers.
+ dt_scale : float, default=1.0
+ Scaling factor for decision tree.
+ dt_init : str, default="random"
+ Initialization method for decision tree.
+ dt_max : float, default=0.1
+ Maximum value for decision tree initialization.
+ dt_min : float, default=1e-04
+ Minimum value for decision tree initialization.
+ dt_init_floor : float, default=1e-04
+ Floor value for decision tree initialization.
+ norm : str, default="RMSNorm"
+ Normalization method to be used.
+ activation : callable, default=nn.SELU()
+ Activation function for the model.
+ num_embedding_activation : callable, default=nn.Identity()
+ Activation function for numerical embeddings.
+ head_layer_sizes : list, default=(128, 64, 32)
+ Sizes of the layers in the head of the model.
+ head_dropout : float, default=0.5
+ Dropout rate for the head layers.
+ head_skip_layers : bool, default=False
+ Whether to skip layers in the head.
+ head_activation : callable, default=nn.SELU()
+ Activation function for the head layers.
+ head_use_batch_norm : bool, default=False
+ Whether to use batch normalization in the head layers.
+ layer_norm_after_embedding : bool, default=False
+ Whether to apply layer normalization after embedding.
+ """
+
+ lr: float = 1e-04
+ lr_patience: int = 10
+ weight_decay: float = 1e-06
+ lr_factor: float = 0.1
+ d_model: int = 64
+ n_layers: int = 1
+ expand_factor: int = 2
+ bias: bool = False
+ d_conv: int = 16
+ conv_bias: bool = True
+ dropout: float = 0.05
+ dt_rank: str = "auto"
+ d_state: int = 128
+ dt_scale: float = 1.0
+ dt_init: str = "random"
+ dt_max: float = 0.1
+ dt_min: float = 1e-04
+ dt_init_floor: float = 1e-04
+ activation: callable = nn.ReLU()
+ num_embedding_activation: callable = nn.ReLU()
+ head_layer_sizes: list = ()
+ head_dropout: float = 0.0
+ head_skip_layers: bool = False
+ head_activation: callable = nn.ReLU()
+ head_use_batch_norm: bool = False
+ norm: str = "LayerNorm"
+ axis: int = 1
diff --git a/mambular/configs/mambular_config.py b/mambular/configs/mambular_config.py
index 666750c..2083961 100644
--- a/mambular/configs/mambular_config.py
+++ b/mambular/configs/mambular_config.py
@@ -49,8 +49,8 @@ class DefaultMambularConfig:
Normalization method to be used.
activation : callable, default=nn.SELU()
Activation function for the model.
- num_embedding_activation : callable, default=nn.Identity()
- Activation function for numerical embeddings.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
head_layer_sizes : list, default=(128, 64, 32)
Sizes of the layers in the head of the model.
head_dropout : float, default=0.5
@@ -69,6 +69,18 @@ class DefaultMambularConfig:
Whether to use bidirectional processing of the input sequences.
use_learnable_interaction : bool, default=False
Whether to use learnable feature interactions before passing through mamba blocks.
+ use_cls : bool, default=True
+ Whether to append a cls to the end of each 'sequence'.
+ shuffle_embeddings : bool, default=False.
+ Whether to shuffle the embeddings before being passed to the Mamba layers.
+ layer_norm_eps : float, default=1e-05
+ Epsilon value for layer normalization.
+ AD_weight_decay : bool, default=True
+ whether weight decay is also applied to A-D matrices.
+ BC_layer_norm: bool, default=False
+ whether to apply layer normalization to B-C matrices.
+ cat_encoding : str, default="int"
+ whether to use integer encoding or one-hot encoding for cat features.
"""
lr: float = 1e-04
@@ -76,23 +88,23 @@ class DefaultMambularConfig:
weight_decay: float = 1e-06
lr_factor: float = 0.1
d_model: int = 64
- n_layers: int = 8
+ n_layers: int = 4
expand_factor: int = 2
bias: bool = False
- d_conv: int = 16
+ d_conv: int = 4
conv_bias: bool = True
- dropout: float = 0.05
+ dropout: float = 0.0
dt_rank: str = "auto"
- d_state: int = 32
+ d_state: int = 128
dt_scale: float = 1.0
dt_init: str = "random"
dt_max: float = 0.1
dt_min: float = 1e-04
dt_init_floor: float = 1e-04
- norm: str = "RMSNorm"
- activation: callable = nn.SELU()
- num_embedding_activation: callable = nn.Identity()
- head_layer_sizes: list = (128, 64, 32)
+ norm: str = "LayerNorm"
+ activation: callable = nn.SiLU()
+ embedding_activation: callable = nn.Identity()
+ head_layer_sizes: list = ()
head_dropout: float = 0.5
head_skip_layers: bool = False
head_activation: callable = nn.SELU()
@@ -101,3 +113,9 @@ class DefaultMambularConfig:
pooling_method: str = "avg"
bidirectional: bool = False
use_learnable_interaction: bool = False
+ use_cls: bool = False
+ shuffle_embeddings: bool = False
+ layer_norm_eps: float = 1e-05
+ AD_weight_decay: bool = True
+ BC_layer_norm: bool = False
+ cat_encoding: str = "int"
diff --git a/mambular/configs/mlp_config.py b/mambular/configs/mlp_config.py
index ee29bbf..adaef3c 100644
--- a/mambular/configs/mlp_config.py
+++ b/mambular/configs/mlp_config.py
@@ -35,13 +35,21 @@ class DefaultMLPConfig:
Whether to use batch normalization in the MLP layers.
layer_norm : bool, default=False
Whether to use layer normalization in the MLP layers.
+ use_embeddings : bool, default=False
+ Whether to use embedding layers for all features.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
+ layer_norm_after_embedding : bool, default=False
+ Whether to apply layer normalization after embedding.
+ d_model : int, default=32
+ Dimensionality of the embeddings.
"""
lr: float = 1e-04
lr_patience: int = 10
weight_decay: float = 1e-06
lr_factor: float = 0.1
- layer_sizes: list = (128, 128, 32)
+ layer_sizes: list = (256, 128, 32)
activation: callable = nn.SELU()
skip_layers: bool = False
dropout: float = 0.5
@@ -50,3 +58,7 @@ class DefaultMLPConfig:
skip_connections: bool = False
batch_norm: bool = False
layer_norm: bool = False
+ use_embeddings: bool = False
+ embedding_activation: callable = nn.Identity()
+ layer_norm_after_embedding: bool = False
+ d_model: int = 32
diff --git a/mambular/configs/resnet_config.py b/mambular/configs/resnet_config.py
index 8722a15..c2fb1bc 100644
--- a/mambular/configs/resnet_config.py
+++ b/mambular/configs/resnet_config.py
@@ -37,13 +37,21 @@ class DefaultResNetConfig:
Whether to use layer normalization in the ResNet layers.
num_blocks : int, default=3
Number of residual blocks in the ResNet.
+ use_embeddings : bool, default=False
+ Whether to use embedding layers for all features.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
+ layer_norm_after_embedding : bool, default=False
+ Whether to apply layer normalization after embedding.
+ d_model : int, default=32
+ Dimensionality of the embeddings.
"""
lr: float = 1e-04
lr_patience: int = 10
weight_decay: float = 1e-06
lr_factor: float = 0.1
- layer_sizes: list = (128, 128, 32)
+ layer_sizes: list = (256, 128, 32)
activation: callable = nn.SELU()
skip_layers: bool = False
dropout: float = 0.5
@@ -53,3 +61,7 @@ class DefaultResNetConfig:
batch_norm: bool = True
layer_norm: bool = False
num_blocks: int = 3
+ use_embeddings: bool = False
+ embedding_activation: callable = nn.Identity()
+ layer_norm_after_embedding: bool = False
+ d_model: int = 32
diff --git a/mambular/configs/tabtransformer_config.py b/mambular/configs/tabtransformer_config.py
index 866f8e4..a1131c9 100644
--- a/mambular/configs/tabtransformer_config.py
+++ b/mambular/configs/tabtransformer_config.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass
import torch.nn as nn
+from ..arch_utils.transformer_utils import ReGLU
@dataclass
@@ -31,8 +32,8 @@ class DefaultTabTransformerConfig:
Normalization method to be used.
activation : callable, default=nn.SELU()
Activation function for the transformer.
- num_embedding_activation : callable, default=nn.Identity()
- Activation function for numerical embeddings.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
head_layer_sizes : list, default=(128, 64, 32)
Sizes of the layers in the head of the model.
head_dropout : float, default=0.5
@@ -57,21 +58,23 @@ class DefaultTabTransformerConfig:
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
+ cat_encoding : str, default="int"
+ whether to use integer encoding or one-hot encoding for cat features.
"""
lr: float = 1e-04
lr_patience: int = 10
weight_decay: float = 1e-06
lr_factor: float = 0.1
- d_model: int = 64
- n_layers: int = 8
- n_heads: int = 4
- attn_dropout: float = 0.3
- ff_dropout: float = 0.3
- norm: str = "RMSNorm"
+ d_model: int = 128
+ n_layers: int = 4
+ n_heads: int = 8
+ attn_dropout: float = 0.2
+ ff_dropout: float = 0.1
+ norm: str = "LayerNorm"
activation: callable = nn.SELU()
- num_embedding_activation: callable = nn.Identity()
- head_layer_sizes: list = (128, 64, 32)
+ embedding_activation: callable = nn.Identity()
+ head_layer_sizes: list = ()
head_dropout: float = 0.5
head_skip_layers: bool = False
head_activation: callable = nn.SELU()
@@ -80,6 +83,7 @@ class DefaultTabTransformerConfig:
pooling_method: str = "avg"
norm_first: bool = True
bias: bool = True
- transformer_activation: callable = nn.SELU()
+ transformer_activation: callable = ReGLU()
layer_norm_eps: float = 1e-05
transformer_dim_feedforward: int = 512
+ cat_encoding: str = "int"
diff --git a/mambular/configs/tabularnn_config.py b/mambular/configs/tabularnn_config.py
new file mode 100644
index 0000000..700181c
--- /dev/null
+++ b/mambular/configs/tabularnn_config.py
@@ -0,0 +1,83 @@
+from dataclasses import dataclass
+import torch.nn as nn
+
+
+@dataclass
+class DefaultTabulaRNNConfig:
+ """
+ Configuration class for the default TabulaRNN model with predefined hyperparameters.
+
+ Parameters
+ ----------
+ lr : float, default=1e-04
+ Learning rate for the optimizer.
+ model_type : str, default="RNN"
+ type of model, one of "RNN", "LSTM", "GRU"
+ lr_patience : int, default=10
+ Number of epochs with no improvement after which learning rate will be reduced.
+ weight_decay : float, default=1e-06
+ Weight decay (L2 penalty) for the optimizer.
+ lr_factor : float, default=0.1
+ Factor by which the learning rate will be reduced.
+ d_model : int, default=64
+ Dimensionality of the model.
+ n_layers : int, default=8
+ Number of layers in the transformer.
+ norm : str, default="RMSNorm"
+ Normalization method to be used.
+ activation : callable, default=nn.SELU()
+ Activation function for the transformer.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for numerical embeddings.
+ head_layer_sizes : list, default=(128, 64, 32)
+ Sizes of the layers in the head of the model.
+ head_dropout : float, default=0.5
+ Dropout rate for the head layers.
+ head_skip_layers : bool, default=False
+ Whether to skip layers in the head.
+ head_activation : callable, default=nn.SELU()
+ Activation function for the head layers.
+ head_use_batch_norm : bool, default=False
+ Whether to use batch normalization in the head layers.
+ layer_norm_after_embedding : bool, default=False
+ Whether to apply layer normalization after embedding.
+ pooling_method : str, default="cls"
+ Pooling method to be used ('cls', 'avg', etc.).
+ norm_first : bool, default=False
+ Whether to apply normalization before other operations in each transformer block.
+ bias : bool, default=True
+ Whether to use bias in the linear layers.
+ rnn_activation : callable, default=nn.SELU()
+ Activation function for the transformer layers.
+ bidirectional : bool, default=False.
+ Whether to process data bidirectionally
+ cat_encoding : str, default="int"
+ Encoding method for categorical features.
+ """
+
+ lr: float = 1e-04
+ model_type: str = "RNN"
+ lr_patience: int = 10
+ weight_decay: float = 1e-06
+ lr_factor: float = 0.1
+ d_model: int = 128
+ n_layers: int = 4
+ rnn_dropout: float = 0.2
+ norm: str = "RMSNorm"
+ activation: callable = nn.SELU()
+ embedding_activation: callable = nn.Identity()
+ head_layer_sizes: list = ()
+ head_dropout: float = 0.5
+ head_skip_layers: bool = False
+ head_activation: callable = nn.SELU()
+ head_use_batch_norm: bool = False
+ layer_norm_after_embedding: bool = False
+ pooling_method: str = "avg"
+ norm_first: bool = False
+ bias: bool = True
+ rnn_activation: str = "relu"
+ layer_norm_eps: float = 1e-05
+ dim_feedforward: int = 256
+ numerical_embedding: str = "ple"
+ bidirectional: bool = False
+ cat_encoding: str = "int"
diff --git a/mambular/models/__init__.py b/mambular/models/__init__.py
index ee6e580..6b9f40c 100644
--- a/mambular/models/__init__.py
+++ b/mambular/models/__init__.py
@@ -1,13 +1,23 @@
-from .fttransformer import (FTTransformerClassifier, FTTransformerLSS,
- FTTransformerRegressor)
+from .fttransformer import (
+ FTTransformerClassifier,
+ FTTransformerLSS,
+ FTTransformerRegressor,
+)
from .mambular import MambularClassifier, MambularLSS, MambularRegressor
from .mlp import MLPLSS, MLPClassifier, MLPRegressor
from .resnet import ResNetClassifier, ResNetLSS, ResNetRegressor
from .sklearn_base_classifier import SklearnBaseClassifier
from .sklearn_base_lss import SklearnBaseLSS
from .sklearn_base_regressor import SklearnBaseRegressor
-from .tabtransformer import (TabTransformerClassifier, TabTransformerLSS,
- TabTransformerRegressor)
+from .tabtransformer import (
+ TabTransformerClassifier,
+ TabTransformerLSS,
+ TabTransformerRegressor,
+)
+
+from .mambatab import MambaTabClassifier, MambaTabRegressor, MambaTabLSS
+from .tabularnn import TabulaRNNClassifier, TabulaRNNRegressor, TabulaRNNLSS
+
__all__ = [
"MambularClassifier",
@@ -28,4 +38,10 @@
"SklearnBaseClassifier",
"SklearnBaseLSS",
"SklearnBaseRegressor",
+ "MambaTabRegressor",
+ "MambaTabClassifier",
+ "MambaTabLSS",
+ "TabulaRNNClassifier",
+ "TabulaRNNRegressor",
+ "TabulaRNNLSS",
]
diff --git a/mambular/models/fttransformer.py b/mambular/models/fttransformer.py
index 095c5ca..efd346e 100644
--- a/mambular/models/fttransformer.py
+++ b/mambular/models/fttransformer.py
@@ -38,8 +38,8 @@ class FTTransformerRegressor(SklearnBaseRegressor):
Normalization method to be used.
activation : callable, default=nn.SELU()
Activation function for the transformer.
- num_embedding_activation : callable, default=nn.Identity()
- Activation function for numerical embeddings.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
head_layer_sizes : list, default=(128, 64, 32)
Sizes of the layers in the head of the model.
head_dropout : float, default=0.5
@@ -64,6 +64,8 @@ class FTTransformerRegressor(SklearnBaseRegressor):
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
+ cat_encoding : str, default="int"
+ whether to use integer encoding or one-hot encoding for cat features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -145,8 +147,8 @@ class FTTransformerClassifier(SklearnBaseClassifier):
Normalization method to be used.
activation : callable, default=nn.SELU()
Activation function for the transformer.
- num_embedding_activation : callable, default=nn.Identity()
- Activation function for numerical embeddings.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
head_layer_sizes : list, default=(128, 64, 32)
Sizes of the layers in the head of the model.
head_dropout : float, default=0.5
@@ -171,6 +173,8 @@ class FTTransformerClassifier(SklearnBaseClassifier):
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
+ cat_encoding : str, default="int"
+ whether to use integer encoding or one-hot encoding for cat features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -252,8 +256,8 @@ class FTTransformerLSS(SklearnBaseLSS):
Normalization method to be used.
activation : callable, default=nn.SELU()
Activation function for the transformer.
- num_embedding_activation : callable, default=nn.Identity()
- Activation function for numerical embeddings.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
head_layer_sizes : list, default=(128, 64, 32)
Sizes of the layers in the head of the model.
head_dropout : float, default=0.5
@@ -278,6 +282,8 @@ class FTTransformerLSS(SklearnBaseLSS):
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
+ cat_encoding : str, default="int"
+ whether to use integer encoding or one-hot encoding for cat features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
diff --git a/mambular/models/mambatab.py b/mambular/models/mambatab.py
new file mode 100644
index 0000000..d8e6fad
--- /dev/null
+++ b/mambular/models/mambatab.py
@@ -0,0 +1,20 @@
+from .sklearn_base_regressor import SklearnBaseRegressor
+from .sklearn_base_lss import SklearnBaseLSS
+from .sklearn_base_classifier import SklearnBaseClassifier
+from ..base_models.mambatab import MambaTab
+from ..configs.mambatab_config import DefaultMambaTabConfig
+
+
+class MambaTabRegressor(SklearnBaseRegressor):
+ def __init__(self, **kwargs):
+ super().__init__(model=MambaTab, config=DefaultMambaTabConfig, **kwargs)
+
+
+class MambaTabClassifier(SklearnBaseClassifier):
+ def __init__(self, **kwargs):
+ super().__init__(model=MambaTab, config=DefaultMambaTabConfig, **kwargs)
+
+
+class MambaTabLSS(SklearnBaseLSS):
+ def __init__(self, **kwargs):
+ super().__init__(model=MambaTab, config=DefaultMambaTabConfig, **kwargs)
diff --git a/mambular/models/mambular.py b/mambular/models/mambular.py
index 11d6862..104448a 100644
--- a/mambular/models/mambular.py
+++ b/mambular/models/mambular.py
@@ -55,8 +55,8 @@ class MambularRegressor(SklearnBaseRegressor):
Normalization method to be used.
activation : callable, default=nn.SELU()
Activation function for the model.
- num_embedding_activation : callable, default=nn.Identity()
- Activation function for numerical embeddings.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
head_layer_sizes : list, default=(128, 64, 32)
Sizes of the layers in the head of the model.
head_dropout : float, default=0.5
@@ -75,6 +75,18 @@ class MambularRegressor(SklearnBaseRegressor):
Whether to use bidirectional processing of the input sequences.
use_learnable_interaction : bool, default=False
Whether to use learnable feature interactions before passing through mamba blocks.
+ use_cls : bool, default=True
+ Whether to append a cls to the end of each 'sequence'.
+ shuffle_embeddings : bool, default=False.
+ Whether to shuffle the embeddings before being passed to the Mamba layers.
+ layer_norm_eps : float, default=1e-05
+ Epsilon value for layer normalization.
+ AD_weight_decay : bool, default=True
+ whether weight decay is also applied to A-D matrices.
+ BC_layer_norm: bool, default=False
+ whether to apply layer normalization to B-C matrices.
+ cat_encoding : str, default="int"
+ whether to use integer encoding or one-hot encoding for cat features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -172,8 +184,8 @@ class MambularClassifier(SklearnBaseClassifier):
Normalization method to be used.
activation : callable, default=nn.SELU()
Activation function for the model.
- num_embedding_activation : callable, default=nn.Identity()
- Activation function for numerical embeddings.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
head_layer_sizes : list, default=(128, 64, 32)
Sizes of the layers in the head of the model.
head_dropout : float, default=0.5
@@ -192,6 +204,16 @@ class MambularClassifier(SklearnBaseClassifier):
Whether to use bidirectional processing of the input sequences.
use_learnable_interaction : bool, default=False
Whether to use learnable feature interactions before passing through mamba blocks.
+ shuffle_embeddings : bool, default=False.
+ Whether to shuffle the embeddings before being passed to the Mamba layers.
+ layer_norm_eps : float, default=1e-05
+ Epsilon value for layer normalization.
+ AD_weight_decay : bool, default=True
+ whether weight decay is also applied to A-D matrices.
+ BC_layer_norm: bool, default=False
+ whether to apply layer normalization to B-C matrices.
+ cat_encoding : str, default="int"
+ whether to use integer encoding or one-hot encoding for cat features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -289,8 +311,8 @@ class MambularLSS(SklearnBaseLSS):
Normalization method to be used.
activation : callable, default=nn.SELU()
Activation function for the model.
- num_embedding_activation : callable, default=nn.Identity()
- Activation function for numerical embeddings.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
head_layer_sizes : list, default=(128, 64, 32)
Sizes of the layers in the head of the model.
head_dropout : float, default=0.5
@@ -312,6 +334,16 @@ class MambularLSS(SklearnBaseLSS):
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
+ shuffle_embeddings : bool, default=False.
+ Whether to shuffle the embeddings before being passed to the Mamba layers.
+ layer_norm_eps : float, default=1e-05
+ Epsilon value for layer normalization.
+ AD_weight_decay : bool, default=True
+ whether weight decay is also applied to A-D matrices.
+ BC_layer_norm: bool, default=False
+ whether to apply layer normalization to B-C matrices.
+ cat_encoding : str, default="int"
+ whether to use integer encoding or one-hot encoding for cat features.
numerical_preprocessing : str, default="ple"
The preprocessing strategy for numerical features. Valid options are
'binning', 'one_hot', 'standardization', and 'normalization'.
diff --git a/mambular/models/mlp.py b/mambular/models/mlp.py
index 46d3d37..fb6baa9 100644
--- a/mambular/models/mlp.py
+++ b/mambular/models/mlp.py
@@ -41,6 +41,14 @@ class MLPRegressor(SklearnBaseRegressor):
Whether to use batch normalization in the MLP layers.
layer_norm : bool, default=False
Whether to use layer normalization in the MLP layers.
+ use_embeddings : bool, default=False
+ Whether to use embedding layers for all features.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
+ layer_norm_after_embedding : bool, default=False
+ Whether to apply layer normalization after embedding.
+ d_model : int, default=32
+ Dimensionality of the embeddings.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -124,6 +132,14 @@ class MLPClassifier(SklearnBaseClassifier):
Whether to use batch normalization in the MLP layers.
layer_norm : bool, default=False
Whether to use layer normalization in the MLP layers.
+ use_embeddings : bool, default=False
+ Whether to use embedding layers for all features.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
+ layer_norm_after_embedding : bool, default=False
+ Whether to apply layer normalization after embedding.
+ d_model : int, default=32
+ Dimensionality of the embeddings.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -207,6 +223,14 @@ class MLPLSS(SklearnBaseLSS):
Whether to use batch normalization in the MLP layers.
layer_norm : bool, default=False
Whether to use layer normalization in the MLP layers.
+ use_embeddings : bool, default=False
+ Whether to use embedding layers for all features.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
+ layer_norm_after_embedding : bool, default=False
+ Whether to apply layer normalization after embedding.
+ d_model : int, default=32
+ Dimensionality of the embeddings.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
diff --git a/mambular/models/resnet.py b/mambular/models/resnet.py
index 8f3cc96..1f6bc5f 100644
--- a/mambular/models/resnet.py
+++ b/mambular/models/resnet.py
@@ -43,6 +43,14 @@ class ResNetRegressor(SklearnBaseRegressor):
Whether to use layer normalization in the ResNet layers.
num_blocks : int, default=3
Number of residual blocks in the ResNet.
+ use_embeddings : bool, default=False
+ Whether to use embedding layers for all features.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
+ layer_norm_after_embedding : bool, default=False
+ Whether to apply layer normalization after embedding.
+ d_model : int, default=32
+ Dimensionality of the embeddings.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -128,6 +136,14 @@ class ResNetClassifier(SklearnBaseClassifier):
Whether to use layer normalization in the ResNet layers.
num_blocks : int, default=3
Number of residual blocks in the ResNet.
+ use_embeddings : bool, default=False
+ Whether to use embedding layers for all features.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
+ layer_norm_after_embedding : bool, default=False
+ Whether to apply layer normalization after embedding.
+ d_model : int, default=32
+ Dimensionality of the embeddings.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -213,6 +229,14 @@ class ResNetLSS(SklearnBaseLSS):
Whether to use layer normalization in the ResNet layers.
num_blocks : int, default=3
Number of residual blocks in the ResNet.
+ use_embeddings : bool, default=False
+ Whether to use embedding layers for all features.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
+ layer_norm_after_embedding : bool, default=False
+ Whether to apply layer normalization after embedding.
+ d_model : int, default=32
+ Dimensionality of the embeddings.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
diff --git a/mambular/models/sklearn_base_classifier.py b/mambular/models/sklearn_base_classifier.py
index 12ebe14..f442688 100644
--- a/mambular/models/sklearn_base_classifier.py
+++ b/mambular/models/sklearn_base_classifier.py
@@ -9,6 +9,8 @@
from ..data_utils.datamodule import MambularDataModule
from ..preprocessing import Preprocessor
import numpy as np
+from lightning.pytorch.callbacks import ModelSummary
+from sklearn.metrics import log_loss
class SklearnBaseClassifier(BaseEstimator):
@@ -35,7 +37,7 @@ def __init__(self, model, config, **kwargs):
}
self.preprocessor = Preprocessor(**preprocessor_kwargs)
- self.model = None
+ self.task_model = None
# Raise a warning if task is set to 'classification'
if preprocessor_kwargs.get("task") == "regression":
@@ -45,26 +47,26 @@ def __init__(self, model, config, **kwargs):
)
self.base_model = model
+ self.built = False
def get_params(self, deep=True):
"""
- Get parameters for this estimator. Overrides the BaseEstimator method.
+ Get parameters for this estimator.
Parameters
----------
deep : bool, default=True
- If True, returns the parameters for this estimator and contained sub-objects that are estimators.
+ If True, will return the parameters for this estimator and contained subobjects that are estimators.
Returns
-------
params : dict
Parameter names mapped to their values.
"""
- params = self.config_kwargs # Parameters used to initialize DefaultConfig
+ params = {}
+ params.update(self.config_kwargs)
- # If deep=True, include parameters from nested components like preprocessor
if deep:
- # Assuming Preprocessor has a get_params method
preprocessor_params = {
"preprocessor__" + key: value
for key, value in self.preprocessor.get_params().items()
@@ -75,63 +77,58 @@ def get_params(self, deep=True):
def set_params(self, **parameters):
"""
- Set the parameters of this estimator. Overrides the BaseEstimator method.
+ Set the parameters of this estimator.
Parameters
----------
**parameters : dict
- Estimator parameters to be set.
+ Estimator parameters.
Returns
-------
self : object
- The instance with updated parameters.
+ Estimator instance.
"""
- # Update config_kwargs with provided parameters
- valid_config_keys = self.config_kwargs.keys()
- config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys}
- self.config_kwargs.update(config_updates)
-
- # Update the config object
- for key, value in config_updates.items():
- setattr(self.config, key, value)
-
- # Handle preprocessor parameters (prefixed with 'preprocessor__')
+ config_params = {
+ k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
+ }
preprocessor_params = {
k.split("__")[1]: v
for k, v in parameters.items()
if k.startswith("preprocessor__")
}
+
+ if config_params:
+ self.config_kwargs.update(config_params)
+ if self.config is not None:
+ for key, value in config_params.items():
+ setattr(self.config, key, value)
+ else:
+ self.config = self.config_class(**self.config_kwargs)
+
if preprocessor_params:
- # Assuming Preprocessor has a set_params method
self.preprocessor.set_params(**preprocessor_params)
return self
- def fit(
+ def build_model(
self,
X,
y,
val_size: float = 0.2,
X_val=None,
y_val=None,
- max_epochs: int = 100,
random_state: int = 101,
batch_size: int = 128,
shuffle: bool = True,
- patience: int = 15,
- monitor: str = "val_loss",
- mode: str = "min",
lr: float = 1e-4,
lr_patience: int = 10,
factor: float = 0.1,
weight_decay: float = 1e-06,
- checkpoint_path="model_checkpoints",
dataloader_kwargs={},
- **trainer_kwargs
):
"""
- Trains the regression model using the provided training data. Optionally, a separate validation set can be used.
+ Builds the model using the provided training data.
Parameters
----------
@@ -145,20 +142,12 @@ def fit(
The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
The validation target values. Required if `X_val` is provided.
- max_epochs : int, default=100
- Maximum number of epochs for training.
random_state : int, default=101
Controls the shuffling applied to the data before applying the split.
batch_size : int, default=64
Number of samples per gradient update.
shuffle : bool, default=True
Whether to shuffle the training data before each epoch.
- patience : int, default=10
- Number of epochs with no improvement on the validation loss to wait before early stopping.
- monitor : str, default="val_loss"
- The metric to monitor for early stopping.
- mode : str, default="min"
- Whether the monitored metric should be minimized (`min`) or maximized (`max`).
lr : float, default=1e-3
Learning rate for the optimizer.
lr_patience : int, default=10
@@ -167,17 +156,15 @@ def fit(
Factor by which the learning rate will be reduced.
weight_decay : float, default=0.025
Weight decay (L2 penalty) coefficient.
- checkpoint_path : str, default="model_checkpoints"
- Path where the checkpoints are being saved.
dataloader_kwargs: dict, default={}
The kwargs for the pytorch dataloader class.
- **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class.
+
Returns
-------
self : object
- The fitted regressor.
+ The built classifier.
"""
if not isinstance(X, pd.DataFrame):
X = pd.DataFrame(X)
@@ -207,7 +194,7 @@ def fit(
num_classes = len(np.unique(y))
- self.model = TaskModel(
+ self.task_model = TaskModel(
model_class=self.base_model,
num_classes=num_classes,
config=self.config,
@@ -219,6 +206,157 @@ def fit(
weight_decay=weight_decay,
)
+ self.built = True
+
+ return self
+
+ def get_number_of_params(self, requires_grad=True):
+ """
+ Calculate the number of parameters in the model.
+
+ Parameters
+ ----------
+ requires_grad : bool, optional
+ If True, only count the parameters that require gradients (trainable parameters).
+ If False, count all parameters. Default is True.
+
+ Returns
+ -------
+ int
+ The total number of parameters in the model.
+
+ Raises
+ ------
+ ValueError
+ If the model has not been built prior to calling this method.
+ """
+ if not self.built:
+ raise ValueError(
+ "The model must be built before the number of parameters can be estimated"
+ )
+ else:
+ if requires_grad:
+ return sum(
+ p.numel() for p in self.task_model.parameters() if p.requires_grad
+ )
+ else:
+ return sum(p.numel() for p in self.task_model.parameters())
+
+ def fit(
+ self,
+ X,
+ y,
+ val_size: float = 0.2,
+ X_val=None,
+ y_val=None,
+ max_epochs: int = 100,
+ random_state: int = 101,
+ batch_size: int = 128,
+ shuffle: bool = True,
+ patience: int = 15,
+ monitor: str = "val_loss",
+ mode: str = "min",
+ lr: float = 1e-4,
+ lr_patience: int = 10,
+ factor: float = 0.1,
+ weight_decay: float = 1e-06,
+ checkpoint_path="model_checkpoints",
+ dataloader_kwargs={},
+ rebuild=True,
+ **trainer_kwargs
+ ):
+ """
+ Trains the classification model using the provided training data. Optionally, a separate validation set can be used.
+
+ Parameters
+ ----------
+ X : DataFrame or array-like, shape (n_samples, n_features)
+ The training input samples.
+ y : array-like, shape (n_samples,) or (n_samples, n_targets)
+ The target values (real numbers).
+ val_size : float, default=0.2
+ The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided.
+ X_val : DataFrame or array-like, shape (n_samples, n_features), optional
+ The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
+ y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
+ The validation target values. Required if `X_val` is provided.
+ max_epochs : int, default=100
+ Maximum number of epochs for training.
+ random_state : int, default=101
+ Controls the shuffling applied to the data before applying the split.
+ batch_size : int, default=64
+ Number of samples per gradient update.
+ shuffle : bool, default=True
+ Whether to shuffle the training data before each epoch.
+ patience : int, default=10
+ Number of epochs with no improvement on the validation loss to wait before early stopping.
+ monitor : str, default="val_loss"
+ The metric to monitor for early stopping.
+ mode : str, default="min"
+ Whether the monitored metric should be minimized (`min`) or maximized (`max`).
+ lr : float, default=1e-3
+ Learning rate for the optimizer.
+ lr_patience : int, default=10
+ Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
+ factor : float, default=0.1
+ Factor by which the learning rate will be reduced.
+ weight_decay : float, default=0.025
+ Weight decay (L2 penalty) coefficient.
+ checkpoint_path : str, default="model_checkpoints"
+ Path where the checkpoints are being saved.
+ dataloader_kwargs: dict, default={}
+ The kwargs for the pytorch dataloader class.
+ rebuild: bool, default=True
+ Whether to rebuild the model when it already was built.
+ **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class.
+
+
+ Returns
+ -------
+ self : object
+ The fitted classifier.
+ """
+ if not self.built and not rebuild:
+ if not isinstance(X, pd.DataFrame):
+ X = pd.DataFrame(X)
+ if isinstance(y, pd.Series):
+ y = y.values
+ if X_val:
+ if not isinstance(X_val, pd.DataFrame):
+ X_val = pd.DataFrame(X_val)
+ if isinstance(y_val, pd.Series):
+ y_val = y_val.values
+
+ self.data_module = MambularDataModule(
+ preprocessor=self.preprocessor,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ X_val=X_val,
+ y_val=y_val,
+ val_size=val_size,
+ random_state=random_state,
+ regression=False,
+ **dataloader_kwargs
+ )
+
+ self.data_module.preprocess_data(
+ X, y, X_val, y_val, val_size=val_size, random_state=random_state
+ )
+
+ num_classes = len(np.unique(y))
+
+ self.task_model = TaskModel(
+ model_class=self.base_model,
+ num_classes=num_classes,
+ config=self.config,
+ cat_feature_info=self.data_module.cat_feature_info,
+ num_feature_info=self.data_module.num_feature_info,
+ lr=lr,
+ lr_patience=lr_patience,
+ lr_factor=factor,
+ weight_decay=weight_decay,
+ )
+
early_stop_callback = EarlyStopping(
monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode
)
@@ -232,17 +370,21 @@ def fit(
)
# Initialize the trainer and train the model
- trainer = pl.Trainer(
+ self.trainer = pl.Trainer(
max_epochs=max_epochs,
- callbacks=[early_stop_callback, checkpoint_callback],
+ callbacks=[
+ early_stop_callback,
+ checkpoint_callback,
+ ModelSummary(max_depth=2),
+ ],
**trainer_kwargs
)
- trainer.fit(self.model, self.data_module)
+ self.trainer.fit(self.task_model, self.data_module)
best_model_path = checkpoint_callback.best_model_path
if best_model_path:
checkpoint = torch.load(best_model_path)
- self.model.load_state_dict(checkpoint["state_dict"])
+ self.task_model.load_state_dict(checkpoint["state_dict"])
return self
@@ -262,14 +404,14 @@ def predict(self, X):
The predicted target values.
"""
# Ensure model and data module are initialized
- if self.model is None or self.data_module is None:
+ if self.task_model is None or self.data_module is None:
raise ValueError("The model or data module has not been fitted yet.")
# Preprocess the data using the data module
cat_tensors, num_tensors = self.data_module.preprocess_test_data(X)
# Move tensors to appropriate device
- device = next(self.model.parameters()).device
+ device = next(self.task_model.parameters()).device
if isinstance(cat_tensors, list):
cat_tensors = [tensor.to(device) for tensor in cat_tensors]
else:
@@ -281,11 +423,11 @@ def predict(self, X):
num_tensors = num_tensors.to(device)
# Set model to evaluation mode
- self.model.eval()
+ self.task_model.eval()
# Perform inference
with torch.no_grad():
- logits = self.model(num_features=num_tensors, cat_features=cat_tensors)
+ logits = self.task_model(num_features=num_tensors, cat_features=cat_tensors)
# Check the shape of the logits to determine binary or multi-class classification
if logits.shape[1] == 1:
@@ -342,7 +484,7 @@ def predict_proba(self, X):
# Preprocess the data
if not isinstance(X, pd.DataFrame):
X = pd.DataFrame(X)
- device = next(self.model.parameters()).device
+ device = next(self.task_model.parameters()).device
cat_tensors, num_tensors = self.data_module.preprocess_test_data(X)
if isinstance(cat_tensors, list):
cat_tensors = [tensor.to(device) for tensor in cat_tensors]
@@ -355,11 +497,11 @@ def predict_proba(self, X):
num_tensors = num_tensors.to(device)
# Set the model to evaluation mode
- self.model.eval()
+ self.task_model.eval()
# Perform inference
with torch.no_grad():
- logits = self.model(num_features=num_tensors, cat_features=cat_tensors)
+ logits = self.task_model(num_features=num_tensors, cat_features=cat_tensors)
if logits.shape[1] > 1:
probabilities = torch.softmax(logits, dim=1)
else:
@@ -419,3 +561,33 @@ def evaluate(self, X, y_true, metrics=None):
scores[metric_name] = metric_func(y_true, predictions)
return scores
+
+ def score(self, X, y, metric=(log_loss, True)):
+ """
+ Calculate the score of the model using the specified metric.
+
+ Parameters
+ ----------
+ X : array-like or pd.DataFrame of shape (n_samples, n_features)
+ The input samples to predict.
+ y : array-like of shape (n_samples,)
+ The true class labels against which to evaluate the predictions.
+ metric : tuple, default=(log_loss, True)
+ A tuple containing the metric function and a boolean indicating whether the metric requires probability scores (True) or class labels (False).
+
+ Returns
+ -------
+ score : float
+ The score calculated using the specified metric.
+ """
+ metric_func, use_proba = metric
+
+ if not isinstance(X, pd.DataFrame):
+ X = pd.DataFrame(X)
+
+ if use_proba:
+ probabilities = self.predict_proba(X)
+ return metric_func(y, probabilities)
+ else:
+ predictions = self.predict(X)
+ return metric_func(y, predictions)
diff --git a/mambular/models/sklearn_base_lss.py b/mambular/models/sklearn_base_lss.py
index 7583055..ad7100f 100644
--- a/mambular/models/sklearn_base_lss.py
+++ b/mambular/models/sklearn_base_lss.py
@@ -31,6 +31,7 @@
PoissonDistribution,
StudentTDistribution,
)
+from lightning.pytorch.callbacks import ModelSummary
class SklearnBaseLSS(BaseEstimator):
@@ -57,7 +58,7 @@ def __init__(self, model, config, **kwargs):
}
self.preprocessor = Preprocessor(**preprocessor_kwargs)
- self.model = None
+ self.task_model = None
# Raise a warning if task is set to 'classification'
if preprocessor_kwargs.get("task") == "classification":
@@ -70,23 +71,22 @@ def __init__(self, model, config, **kwargs):
def get_params(self, deep=True):
"""
- Get parameters for this estimator. Overrides the BaseEstimator method.
+ Get parameters for this estimator.
Parameters
----------
deep : bool, default=True
- If True, returns the parameters for this estimator and contained sub-objects that are estimators.
+ If True, will return the parameters for this estimator and contained subobjects that are estimators.
Returns
-------
params : dict
Parameter names mapped to their values.
"""
- params = self.config_kwargs # Parameters used to initialize DefaultConfig
+ params = {}
+ params.update(self.config_kwargs)
- # If deep=True, include parameters from nested components like preprocessor
if deep:
- # Assuming Preprocessor has a get_params method
preprocessor_params = {
"preprocessor__" + key: value
for key, value in self.preprocessor.get_params().items()
@@ -97,39 +97,169 @@ def get_params(self, deep=True):
def set_params(self, **parameters):
"""
- Set the parameters of this estimator. Overrides the BaseEstimator method.
+ Set the parameters of this estimator.
Parameters
----------
**parameters : dict
- Estimator parameters to be set.
+ Estimator parameters.
Returns
-------
self : object
- The instance with updated parameters.
+ Estimator instance.
"""
- # Update config_kwargs with provided parameters
- valid_config_keys = self.config_kwargs.keys()
- config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys}
- self.config_kwargs.update(config_updates)
-
- # Update the config object
- for key, value in config_updates.items():
- setattr(self.config, key, value)
-
- # Handle preprocessor parameters (prefixed with 'preprocessor__')
+ config_params = {
+ k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
+ }
preprocessor_params = {
k.split("__")[1]: v
for k, v in parameters.items()
if k.startswith("preprocessor__")
}
+
+ if config_params:
+ self.config_kwargs.update(config_params)
+ if self.config is not None:
+ for key, value in config_params.items():
+ setattr(self.config, key, value)
+ else:
+ self.config = self.config_class(**self.config_kwargs)
+
if preprocessor_params:
- # Assuming Preprocessor has a set_params method
self.preprocessor.set_params(**preprocessor_params)
return self
+ def build_model(
+ self,
+ X,
+ y,
+ val_size: float = 0.2,
+ X_val=None,
+ y_val=None,
+ random_state: int = 101,
+ batch_size: int = 128,
+ shuffle: bool = True,
+ lr: float = 1e-4,
+ lr_patience: int = 10,
+ factor: float = 0.1,
+ weight_decay: float = 1e-06,
+ dataloader_kwargs={},
+ ):
+ """
+ Builds the model using the provided training data.
+
+ Parameters
+ ----------
+ X : DataFrame or array-like, shape (n_samples, n_features)
+ The training input samples.
+ y : array-like, shape (n_samples,) or (n_samples, n_targets)
+ The target values (real numbers).
+ val_size : float, default=0.2
+ The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided.
+ X_val : DataFrame or array-like, shape (n_samples, n_features), optional
+ The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
+ y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
+ The validation target values. Required if `X_val` is provided.
+ random_state : int, default=101
+ Controls the shuffling applied to the data before applying the split.
+ batch_size : int, default=64
+ Number of samples per gradient update.
+ shuffle : bool, default=True
+ Whether to shuffle the training data before each epoch.
+ lr : float, default=1e-3
+ Learning rate for the optimizer.
+ lr_patience : int, default=10
+ Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
+ factor : float, default=0.1
+ Factor by which the learning rate will be reduced.
+ weight_decay : float, default=0.025
+ Weight decay (L2 penalty) coefficient.
+ dataloader_kwargs: dict, default={}
+ The kwargs for the pytorch dataloader class.
+
+ Returns
+ -------
+ self : object
+ The built distributional regressor.
+ """
+ if not isinstance(X, pd.DataFrame):
+ X = pd.DataFrame(X)
+ if isinstance(y, pd.Series):
+ y = y.values
+ if X_val:
+ if not isinstance(X_val, pd.DataFrame):
+ X_val = pd.DataFrame(X_val)
+ if isinstance(y_val, pd.Series):
+ y_val = y_val.values
+
+ self.data_module = MambularDataModule(
+ preprocessor=self.preprocessor,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ X_val=X_val,
+ y_val=y_val,
+ val_size=val_size,
+ random_state=random_state,
+ regression=False,
+ **dataloader_kwargs
+ )
+
+ self.data_module.preprocess_data(
+ X, y, X_val, y_val, val_size=val_size, random_state=random_state
+ )
+
+ num_classes = len(np.unique(y))
+
+ self.task_model = TaskModel(
+ model_class=self.base_model,
+ num_classes=num_classes,
+ config=self.config,
+ cat_feature_info=self.data_module.cat_feature_info,
+ num_feature_info=self.data_module.num_feature_info,
+ lr=lr,
+ lr_patience=lr_patience,
+ lr_factor=factor,
+ weight_decay=weight_decay,
+ )
+
+ self.built = True
+
+ return self
+
+ def get_number_of_params(self, requires_grad=True):
+ """
+ Calculate the number of parameters in the model.
+
+ Parameters
+ ----------
+ requires_grad : bool, optional
+ If True, only count the parameters that require gradients (trainable parameters).
+ If False, count all parameters. Default is True.
+
+ Returns
+ -------
+ int
+ The total number of parameters in the model.
+
+ Raises
+ ------
+ ValueError
+ If the model has not been built prior to calling this method.
+ """
+ if not self.built:
+ raise ValueError(
+ "The model must be built before the number of parameters can be estimated"
+ )
+ else:
+ if requires_grad:
+ return sum(
+ p.numel() for p in self.task_model.parameters() if p.requires_grad
+ )
+ else:
+ return sum(p.numel() for p in self.task_model.parameters())
+
def fit(
self,
X,
@@ -253,7 +383,7 @@ def fit(
X, y, X_val, y_val, val_size=val_size, random_state=random_state
)
- self.model = TaskModel(
+ self.task_model = TaskModel(
model_class=self.base_model,
num_classes=self.family.param_count,
family=self.family,
@@ -280,17 +410,21 @@ def fit(
)
# Initialize the trainer and train the model
- trainer = pl.Trainer(
+ self.trainer = pl.Trainer(
max_epochs=max_epochs,
- callbacks=[early_stop_callback, checkpoint_callback],
+ callbacks=[
+ early_stop_callback,
+ checkpoint_callback,
+ ModelSummary(max_depth=2),
+ ],
**trainer_kwargs
)
- trainer.fit(self.model, self.data_module)
+ self.trainer.fit(self.task_model, self.data_module)
best_model_path = checkpoint_callback.best_model_path
if best_model_path:
checkpoint = torch.load(best_model_path)
- self.model.load_state_dict(checkpoint["state_dict"])
+ self.task_model.load_state_dict(checkpoint["state_dict"])
return self
@@ -310,14 +444,14 @@ def predict(self, X, raw=False):
The predicted target values.
"""
# Ensure model and data module are initialized
- if self.model is None or self.data_module is None:
+ if self.task_model is None or self.data_module is None:
raise ValueError("The model or data module has not been fitted yet.")
# Preprocess the data using the data module
cat_tensors, num_tensors = self.data_module.preprocess_test_data(X)
# Move tensors to appropriate device
- device = next(self.model.parameters()).device
+ device = next(self.task_model.parameters()).device
if isinstance(cat_tensors, list):
cat_tensors = [tensor.to(device) for tensor in cat_tensors]
else:
@@ -329,14 +463,16 @@ def predict(self, X, raw=False):
num_tensors = num_tensors.to(device)
# Set model to evaluation mode
- self.model.eval()
+ self.task_model.eval()
# Perform inference
with torch.no_grad():
- predictions = self.model(num_features=num_tensors, cat_features=cat_tensors)
+ predictions = self.task_model(
+ num_features=num_tensors, cat_features=cat_tensors
+ )
if not raw:
- return self.model.family(predictions).cpu().numpy()
+ return self.task_model.family(predictions).cpu().numpy()
# Convert predictions to NumPy array and return
else:
@@ -372,7 +508,9 @@ def evaluate(self, X, y_true, metrics=None, distribution_family=None):
"""
# Infer distribution family from model settings if not provided
if distribution_family is None:
- distribution_family = getattr(self.model, "distribution_family", "normal")
+ distribution_family = getattr(
+ self.task_model, "distribution_family", "normal"
+ )
# Setup default metrics if none are provided
if metrics is None:
@@ -425,3 +563,25 @@ def get_default_metrics(self, distribution_family):
"categorical": {"Accuracy": accuracy_score},
}
return default_metrics.get(distribution_family, {})
+
+ def score(self, X, y, metric="NLL"):
+ """
+ Calculate the score of the model using the specified metric.
+
+ Parameters
+ ----------
+ X : array-like or pd.DataFrame of shape (n_samples, n_features)
+ The input samples to predict.
+ y : array-like of shape (n_samples,) or (n_samples, n_outputs)
+ The true target values against which to evaluate the predictions.
+ metric : str, default="NLL"
+ So far, only negative log-likelihood is supported
+
+ Returns
+ -------
+ score : float
+ The score calculated using the specified metric.
+ """
+ predictions = self.predict(X)
+ score = self.task_model.family.evaluate_nll(y, predictions)
+ return score
diff --git a/mambular/models/sklearn_base_regressor.py b/mambular/models/sklearn_base_regressor.py
index faa7674..1a098ac 100644
--- a/mambular/models/sklearn_base_regressor.py
+++ b/mambular/models/sklearn_base_regressor.py
@@ -8,11 +8,13 @@
from ..base_models.lightning_wrapper import TaskModel
from ..data_utils.datamodule import MambularDataModule
from ..preprocessing import Preprocessor
+from lightning.pytorch.callbacks import ModelSummary
+from dataclasses import asdict, is_dataclass
class SklearnBaseRegressor(BaseEstimator):
def __init__(self, model, config, **kwargs):
- preprocessor_arg_names = [
+ self.preprocessor_arg_names = [
"n_bins",
"numerical_preprocessing",
"use_decision_tree_bins",
@@ -25,16 +27,18 @@ def __init__(self, model, config, **kwargs):
]
self.config_kwargs = {
- k: v for k, v in kwargs.items() if k not in preprocessor_arg_names
+ k: v for k, v in kwargs.items() if k not in self.preprocessor_arg_names
}
self.config = config(**self.config_kwargs)
preprocessor_kwargs = {
- k: v for k, v in kwargs.items() if k in preprocessor_arg_names
+ k: v for k, v in kwargs.items() if k in self.preprocessor_arg_names
}
self.preprocessor = Preprocessor(**preprocessor_kwargs)
- self.model = None
+ self.base_model = model
+ self.task_model = None
+ self.built = False
# Raise a warning if task is set to 'classification'
if preprocessor_kwargs.get("task") == "classification":
@@ -43,27 +47,24 @@ def __init__(self, model, config, **kwargs):
UserWarning,
)
- self.base_model = model
-
def get_params(self, deep=True):
"""
- Get parameters for this estimator. Overrides the BaseEstimator method.
+ Get parameters for this estimator.
Parameters
----------
deep : bool, default=True
- If True, returns the parameters for this estimator and contained sub-objects that are estimators.
+ If True, will return the parameters for this estimator and contained subobjects that are estimators.
Returns
-------
params : dict
Parameter names mapped to their values.
"""
- params = self.config_kwargs # Parameters used to initialize DefaultConfig
+ params = {}
+ params.update(self.config_kwargs)
- # If deep=True, include parameters from nested components like preprocessor
if deep:
- # Assuming Preprocessor has a get_params method
preprocessor_params = {
"preprocessor__" + key: value
for key, value in self.preprocessor.get_params().items()
@@ -74,63 +75,58 @@ def get_params(self, deep=True):
def set_params(self, **parameters):
"""
- Set the parameters of this estimator. Overrides the BaseEstimator method.
+ Set the parameters of this estimator.
Parameters
----------
**parameters : dict
- Estimator parameters to be set.
+ Estimator parameters.
Returns
-------
self : object
- The instance with updated parameters.
+ Estimator instance.
"""
- # Update config_kwargs with provided parameters
- valid_config_keys = self.config_kwargs.keys()
- config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys}
- self.config_kwargs.update(config_updates)
-
- # Update the config object
- for key, value in config_updates.items():
- setattr(self.config, key, value)
-
- # Handle preprocessor parameters (prefixed with 'preprocessor__')
+ config_params = {
+ k: v for k, v in parameters.items() if not k.startswith("preprocessor__")
+ }
preprocessor_params = {
k.split("__")[1]: v
for k, v in parameters.items()
if k.startswith("preprocessor__")
}
+
+ if config_params:
+ self.config_kwargs.update(config_params)
+ if self.config is not None:
+ for key, value in config_params.items():
+ setattr(self.config, key, value)
+ else:
+ self.config = self.config_class(**self.config_kwargs)
+
if preprocessor_params:
- # Assuming Preprocessor has a set_params method
self.preprocessor.set_params(**preprocessor_params)
return self
- def fit(
+ def build_model(
self,
X,
y,
val_size: float = 0.2,
X_val=None,
y_val=None,
- max_epochs: int = 100,
random_state: int = 101,
batch_size: int = 128,
shuffle: bool = True,
- patience: int = 15,
- monitor: str = "val_loss",
- mode: str = "min",
lr: float = 1e-4,
lr_patience: int = 10,
factor: float = 0.1,
weight_decay: float = 1e-06,
- checkpoint_path="model_checkpoints",
dataloader_kwargs={},
- **trainer_kwargs
):
"""
- Trains the regression model using the provided training data. Optionally, a separate validation set can be used.
+ Builds the model using the provided training data.
Parameters
----------
@@ -144,20 +140,12 @@ def fit(
The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
The validation target values. Required if `X_val` is provided.
- max_epochs : int, default=100
- Maximum number of epochs for training.
random_state : int, default=101
Controls the shuffling applied to the data before applying the split.
batch_size : int, default=64
Number of samples per gradient update.
shuffle : bool, default=True
Whether to shuffle the training data before each epoch.
- patience : int, default=10
- Number of epochs with no improvement on the validation loss to wait before early stopping.
- monitor : str, default="val_loss"
- The metric to monitor for early stopping.
- mode : str, default="min"
- Whether the monitored metric should be minimized (`min`) or maximized (`max`).
lr : float, default=1e-3
Learning rate for the optimizer.
lr_patience : int, default=10
@@ -166,17 +154,15 @@ def fit(
Factor by which the learning rate will be reduced.
weight_decay : float, default=0.025
Weight decay (L2 penalty) coefficient.
- checkpoint_path : str, default="model_checkpoints"
- Path where the checkpoints are being saved.
dataloader_kwargs: dict, default={}
The kwargs for the pytorch dataloader class.
- **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class.
+
Returns
-------
self : object
- The fitted regressor.
+ The built regressor.
"""
if not isinstance(X, pd.DataFrame):
X = pd.DataFrame(X)
@@ -204,7 +190,7 @@ def fit(
X, y, X_val, y_val, val_size=val_size, random_state=random_state
)
- self.model = TaskModel(
+ self.task_model = TaskModel(
model_class=self.base_model,
config=self.config,
cat_feature_info=self.data_module.cat_feature_info,
@@ -215,6 +201,155 @@ def fit(
weight_decay=weight_decay,
)
+ self.built = True
+
+ return self
+
+ def get_number_of_params(self, requires_grad=True):
+ """
+ Calculate the number of parameters in the model.
+
+ Parameters
+ ----------
+ requires_grad : bool, optional
+ If True, only count the parameters that require gradients (trainable parameters).
+ If False, count all parameters. Default is True.
+
+ Returns
+ -------
+ int
+ The total number of parameters in the model.
+
+ Raises
+ ------
+ ValueError
+ If the model has not been built prior to calling this method.
+ """
+ if not self.built:
+ raise ValueError(
+ "The model must be built before the number of parameters can be estimated"
+ )
+ else:
+ if requires_grad:
+ return sum(
+ p.numel() for p in self.task_model.parameters() if p.requires_grad
+ )
+ else:
+ return sum(p.numel() for p in self.task_model.parameters())
+
+ def fit(
+ self,
+ X,
+ y,
+ val_size: float = 0.2,
+ X_val=None,
+ y_val=None,
+ max_epochs: int = 100,
+ random_state: int = 101,
+ batch_size: int = 128,
+ shuffle: bool = True,
+ patience: int = 15,
+ monitor: str = "val_loss",
+ mode: str = "min",
+ lr: float = 1e-4,
+ lr_patience: int = 10,
+ factor: float = 0.1,
+ weight_decay: float = 1e-06,
+ checkpoint_path="model_checkpoints",
+ dataloader_kwargs={},
+ rebuild=True,
+ **trainer_kwargs
+ ):
+ """
+ Trains the regression model using the provided training data. Optionally, a separate validation set can be used.
+
+ Parameters
+ ----------
+ X : DataFrame or array-like, shape (n_samples, n_features)
+ The training input samples.
+ y : array-like, shape (n_samples,) or (n_samples, n_targets)
+ The target values (real numbers).
+ val_size : float, default=0.2
+ The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided.
+ X_val : DataFrame or array-like, shape (n_samples, n_features), optional
+ The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
+ y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
+ The validation target values. Required if `X_val` is provided.
+ max_epochs : int, default=100
+ Maximum number of epochs for training.
+ random_state : int, default=101
+ Controls the shuffling applied to the data before applying the split.
+ batch_size : int, default=64
+ Number of samples per gradient update.
+ shuffle : bool, default=True
+ Whether to shuffle the training data before each epoch.
+ patience : int, default=10
+ Number of epochs with no improvement on the validation loss to wait before early stopping.
+ monitor : str, default="val_loss"
+ The metric to monitor for early stopping.
+ mode : str, default="min"
+ Whether the monitored metric should be minimized (`min`) or maximized (`max`).
+ lr : float, default=1e-3
+ Learning rate for the optimizer.
+ lr_patience : int, default=10
+ Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
+ factor : float, default=0.1
+ Factor by which the learning rate will be reduced.
+ weight_decay : float, default=0.025
+ Weight decay (L2 penalty) coefficient.
+ checkpoint_path : str, default="model_checkpoints"
+ Path where the checkpoints are being saved.
+ dataloader_kwargs: dict, default={}
+ The kwargs for the pytorch dataloader class.
+ **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class.
+
+
+ Returns
+ -------
+ self : object
+ The fitted regressor.
+ """
+ if rebuild:
+ if not isinstance(X, pd.DataFrame):
+ X = pd.DataFrame(X)
+ if isinstance(y, pd.Series):
+ y = y.values
+ if X_val:
+ if not isinstance(X_val, pd.DataFrame):
+ X_val = pd.DataFrame(X_val)
+ if isinstance(y_val, pd.Series):
+ y_val = y_val.values
+
+ self.data_module = MambularDataModule(
+ preprocessor=self.preprocessor,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ X_val=X_val,
+ y_val=y_val,
+ val_size=val_size,
+ random_state=random_state,
+ regression=True,
+ **dataloader_kwargs
+ )
+
+ self.data_module.preprocess_data(
+ X, y, X_val, y_val, val_size=val_size, random_state=random_state
+ )
+
+ self.task_model = TaskModel(
+ model_class=self.base_model,
+ config=self.config,
+ cat_feature_info=self.data_module.cat_feature_info,
+ num_feature_info=self.data_module.num_feature_info,
+ lr=lr,
+ lr_patience=lr_patience,
+ lr_factor=factor,
+ weight_decay=weight_decay,
+ )
+
+ else:
+ assert self.built, "The model must be built before calling the fit method."
+
early_stop_callback = EarlyStopping(
monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode
)
@@ -228,17 +363,21 @@ def fit(
)
# Initialize the trainer and train the model
- trainer = pl.Trainer(
+ self.trainer = pl.Trainer(
max_epochs=max_epochs,
- callbacks=[early_stop_callback, checkpoint_callback],
+ callbacks=[
+ early_stop_callback,
+ checkpoint_callback,
+ ModelSummary(max_depth=2),
+ ],
**trainer_kwargs
)
- trainer.fit(self.model, self.data_module)
+ self.trainer.fit(self.task_model, self.data_module)
best_model_path = checkpoint_callback.best_model_path
if best_model_path:
checkpoint = torch.load(best_model_path)
- self.model.load_state_dict(checkpoint["state_dict"])
+ self.task_model.load_state_dict(checkpoint["state_dict"])
return self
@@ -258,14 +397,14 @@ def predict(self, X):
The predicted target values.
"""
# Ensure model and data module are initialized
- if self.model is None or self.data_module is None:
+ if self.task_model is None or self.data_module is None:
raise ValueError("The model or data module has not been fitted yet.")
# Preprocess the data using the data module
cat_tensors, num_tensors = self.data_module.preprocess_test_data(X)
# Move tensors to appropriate device
- device = next(self.model.parameters()).device
+ device = next(self.task_model.parameters()).device
if isinstance(cat_tensors, list):
cat_tensors = [tensor.to(device) for tensor in cat_tensors]
else:
@@ -277,11 +416,13 @@ def predict(self, X):
num_tensors = num_tensors.to(device)
# Set model to evaluation mode
- self.model.eval()
+ self.task_model.eval()
# Perform inference
with torch.no_grad():
- predictions = self.model(num_features=num_tensors, cat_features=cat_tensors)
+ predictions = self.task_model(
+ num_features=num_tensors, cat_features=cat_tensors
+ )
# Convert predictions to NumPy array and return
return predictions.cpu().numpy()
@@ -338,3 +479,24 @@ def evaluate(self, X, y_true, metrics=None):
scores[metric_name] = metric_func(y_true, predictions)
return scores
+
+ def score(self, X, y, metric=mean_squared_error):
+ """
+ Calculate the score of the model using the specified metric.
+
+ Parameters
+ ----------
+ X : array-like or pd.DataFrame of shape (n_samples, n_features)
+ The input samples to predict.
+ y : array-like of shape (n_samples,) or (n_samples, n_outputs)
+ The true target values against which to evaluate the predictions.
+ metric : callable, default=mean_squared_error
+ The metric function to use for evaluation. Must be a callable with the signature `metric(y_true, y_pred)`.
+
+ Returns
+ -------
+ score : float
+ The score calculated using the specified metric.
+ """
+ predictions = self.predict(X)
+ return metric(y, predictions)
diff --git a/mambular/models/tabtransformer.py b/mambular/models/tabtransformer.py
index e5cfe7d..901369e 100644
--- a/mambular/models/tabtransformer.py
+++ b/mambular/models/tabtransformer.py
@@ -37,8 +37,8 @@ class TabTransformerRegressor(SklearnBaseRegressor):
Normalization method to be used.
activation : callable, default=nn.SELU()
Activation function for the transformer.
- num_embedding_activation : callable, default=nn.Identity()
- Activation function for numerical embeddings.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
head_layer_sizes : list, default=(128, 64, 32)
Sizes of the layers in the head of the model.
head_dropout : float, default=0.5
@@ -63,6 +63,8 @@ class TabTransformerRegressor(SklearnBaseRegressor):
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
+ cat_encoding : str, default="int"
+ whether to use integer encoding or one-hot encoding for cat features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -144,8 +146,8 @@ class TabTransformerClassifier(SklearnBaseClassifier):
Normalization method to be used.
activation : callable, default=nn.SELU()
Activation function for the transformer.
- num_embedding_activation : callable, default=nn.Identity()
- Activation function for numerical embeddings.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
head_layer_sizes : list, default=(128, 64, 32)
Sizes of the layers in the head of the model.
head_dropout : float, default=0.5
@@ -170,6 +172,8 @@ class TabTransformerClassifier(SklearnBaseClassifier):
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
+ cat_encoding : str, default="int"
+ whether to use integer encoding or one-hot encoding for cat features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
@@ -251,8 +255,8 @@ class TabTransformerLSS(SklearnBaseLSS):
Normalization method to be used.
activation : callable, default=nn.SELU()
Activation function for the transformer.
- num_embedding_activation : callable, default=nn.Identity()
- Activation function for numerical embeddings.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for embeddings.
head_layer_sizes : list, default=(128, 64, 32)
Sizes of the layers in the head of the model.
head_dropout : float, default=0.5
@@ -277,6 +281,8 @@ class TabTransformerLSS(SklearnBaseLSS):
Epsilon value for layer normalization.
transformer_dim_feedforward : int, default=512
Dimensionality of the feed-forward layers in the transformer.
+ cat_encoding : str, default="int"
+ whether to use integer encoding or one-hot encoding for cat features.
n_bins : int, default=50
The number of bins to use for numerical feature binning. This parameter is relevant
only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
diff --git a/mambular/models/tabularnn.py b/mambular/models/tabularnn.py
new file mode 100644
index 0000000..60daf2a
--- /dev/null
+++ b/mambular/models/tabularnn.py
@@ -0,0 +1,255 @@
+from .sklearn_base_regressor import SklearnBaseRegressor
+from .sklearn_base_classifier import SklearnBaseClassifier
+from .sklearn_base_lss import SklearnBaseLSS
+
+from ..base_models.tabularnn import TabulaRNN
+from ..configs.tabularnn_config import DefaultTabulaRNNConfig
+
+
+class TabulaRNNRegressor(SklearnBaseRegressor):
+ """
+ RNN regressor. This class extends the SklearnBaseRegressor class and uses the TabulaRNN model
+ with the default TabulaRNN configuration.
+
+ The accepted arguments to the TabulaRNNRegressor class include both the attributes in the DefaultTabulaRNNConfig dataclass
+ and the parameters for the Preprocessor class.
+
+ Parameters
+ ----------
+ lr : float, default=1e-04
+ Learning rate for the optimizer.
+ model_type : str, default="RNN"
+ type of model, one of "RNN", "LSTM", "GRU"
+ lr_patience : int, default=10
+ Number of epochs with no improvement after which learning rate will be reduced.
+ weight_decay : float, default=1e-06
+ Weight decay (L2 penalty) for the optimizer.
+ lr_factor : float, default=0.1
+ Factor by which the learning rate will be reduced.
+ d_model : int, default=64
+ Dimensionality of the model.
+ n_layers : int, default=8
+ Number of layers in the transformer.
+ norm : str, default="RMSNorm"
+ Normalization method to be used.
+ activation : callable, default=nn.SELU()
+ Activation function for the transformer.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for numerical embeddings.
+ head_layer_sizes : list, default=(128, 64, 32)
+ Sizes of the layers in the head of the model.
+ head_dropout : float, default=0.5
+ Dropout rate for the head layers.
+ head_skip_layers : bool, default=False
+ Whether to skip layers in the head.
+ head_activation : callable, default=nn.SELU()
+ Activation function for the head layers.
+ head_use_batch_norm : bool, default=False
+ Whether to use batch normalization in the head layers.
+ layer_norm_after_embedding : bool, default=False
+ Whether to apply layer normalization after embedding.
+ pooling_method : str, default="cls"
+ Pooling method to be used ('cls', 'avg', etc.).
+ norm_first : bool, default=False
+ Whether to apply normalization before other operations in each transformer block.
+ bias : bool, default=True
+ Whether to use bias in the linear layers.
+ rnn_activation : callable, default=nn.SELU()
+ Activation function for the transformer layers.
+ bidirectional : bool, default=False.
+ Whether to process data bidirectionally
+ cat_encoding : str, default="int"
+ Encoding method for categorical features.
+ n_bins : int, default=50
+ The number of bins to use for numerical feature binning. This parameter is relevant
+ only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
+ numerical_preprocessing : str, default="ple"
+ The preprocessing strategy for numerical features. Valid options are
+ 'binning', 'one_hot', 'standardization', and 'normalization'.
+ use_decision_tree_bins : bool, default=False
+ If True, uses decision tree regression/classification to determine
+ optimal bin edges for numerical feature binning. This parameter is
+ relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
+ binning_strategy : str, default="uniform"
+ Defines the strategy for binning numerical features. Options include 'uniform',
+ 'quantile', or other sklearn-compatible strategies.
+ cat_cutoff : float or int, default=0.03
+ Indicates the cutoff after which integer values are treated as categorical.
+ If float, it's treated as a percentage. If int, it's the maximum number of
+ unique values for a column to be considered categorical.
+ treat_all_integers_as_numerical : bool, default=False
+ If True, all integer columns will be treated as numerical, regardless
+ of their unique value count or proportion.
+ degree : int, default=3
+ The degree of the polynomial features to be used in preprocessing.
+ knots : int, default=12
+ The number of knots to be used in spline transformations.
+ """
+
+
+class TabulaRNNClassifier(SklearnBaseClassifier):
+ """
+ RNN classifier. This class extends the SklearnBaseClassifier class and uses the TabulaRNN model
+ with the default TabulaRNN configuration.
+
+ The accepted arguments to the TabulaRNNClassifier class include both the attributes in the DefaultTabulaRNNConfig dataclass
+ and the parameters for the Preprocessor class.
+
+ Parameters
+ ----------
+ lr : float, default=1e-04
+ Learning rate for the optimizer.
+ model_type : str, default="RNN"
+ type of model, one of "RNN", "LSTM", "GRU"
+ lr_patience : int, default=10
+ Number of epochs with no improvement after which learning rate will be reduced.
+ weight_decay : float, default=1e-06
+ Weight decay (L2 penalty) for the optimizer.
+ lr_factor : float, default=0.1
+ Factor by which the learning rate will be reduced.
+ d_model : int, default=64
+ Dimensionality of the model.
+ n_layers : int, default=8
+ Number of layers in the transformer.
+ norm : str, default="RMSNorm"
+ Normalization method to be used.
+ activation : callable, default=nn.SELU()
+ Activation function for the transformer.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for numerical embeddings.
+ head_layer_sizes : list, default=(128, 64, 32)
+ Sizes of the layers in the head of the model.
+ head_dropout : float, default=0.5
+ Dropout rate for the head layers.
+ head_skip_layers : bool, default=False
+ Whether to skip layers in the head.
+ head_activation : callable, default=nn.SELU()
+ Activation function for the head layers.
+ head_use_batch_norm : bool, default=False
+ Whether to use batch normalization in the head layers.
+ layer_norm_after_embedding : bool, default=False
+ Whether to apply layer normalization after embedding.
+ pooling_method : str, default="cls"
+ Pooling method to be used ('cls', 'avg', etc.).
+ norm_first : bool, default=False
+ Whether to apply normalization before other operations in each transformer block.
+ bias : bool, default=True
+ Whether to use bias in the linear layers.
+ rnn_activation : callable, default=nn.SELU()
+ Activation function for the transformer layers.
+ bidirectional : bool, default=False.
+ Whether to process data bidirectionally
+ cat_encoding : str, default="int"
+ Encoding method for categorical features.
+ n_bins : int, default=50
+ The number of bins to use for numerical feature binning. This parameter is relevant
+ only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
+ numerical_preprocessing : str, default="ple"
+ The preprocessing strategy for numerical features. Valid options are
+ 'binning', 'one_hot', 'standardization', and 'normalization'.
+ use_decision_tree_bins : bool, default=False
+ If True, uses decision tree regression/classification to determine
+ optimal bin edges for numerical feature binning. This parameter is
+ relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
+ binning_strategy : str, default="uniform"
+ Defines the strategy for binning numerical features. Options include 'uniform',
+ 'quantile', or other sklearn-compatible strategies.
+ cat_cutoff : float or int, default=0.03
+ Indicates the cutoff after which integer values are treated as categorical.
+ If float, it's treated as a percentage. If int, it's the maximum number of
+ unique values for a column to be considered categorical.
+ treat_all_integers_as_numerical : bool, default=False
+ If True, all integer columns will be treated as numerical, regardless
+ of their unique value count or proportion.
+ degree : int, default=3
+ The degree of the polynomial features to be used in preprocessing.
+ knots : int, default=12
+ The number of knots to be used in spline transformations.
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(model=TabulaRNN, config=DefaultTabulaRNNConfig, **kwargs)
+
+
+class TabulaRNNLSS(SklearnBaseLSS):
+ """
+ RNN LSS. This class extends the SklearnBaseLSS class and uses the TabulaRNN model
+ with the default TabulaRNN configuration.
+
+ The accepted arguments to the TabulaRNNLSS class include both the attributes in the DefaultTabulaRNNConfig dataclass
+ and the parameters for the Preprocessor class.
+
+ Parameters
+ ----------
+ lr : float, default=1e-04
+ Learning rate for the optimizer.
+ model_type : str, default="RNN"
+ type of model, one of "RNN", "LSTM", "GRU"
+ lr_patience : int, default=10
+ Number of epochs with no improvement after which learning rate will be reduced.
+ weight_decay : float, default=1e-06
+ Weight decay (L2 penalty) for the optimizer.
+ lr_factor : float, default=0.1
+ Factor by which the learning rate will be reduced.
+ d_model : int, default=64
+ Dimensionality of the model.
+ n_layers : int, default=8
+ Number of layers in the transformer.
+ norm : str, default="RMSNorm"
+ Normalization method to be used.
+ activation : callable, default=nn.SELU()
+ Activation function for the transformer.
+ embedding_activation : callable, default=nn.Identity()
+ Activation function for numerical embeddings.
+ head_layer_sizes : list, default=(128, 64, 32)
+ Sizes of the layers in the head of the model.
+ head_dropout : float, default=0.5
+ Dropout rate for the head layers.
+ head_skip_layers : bool, default=False
+ Whether to skip layers in the head.
+ head_activation : callable, default=nn.SELU()
+ Activation function for the head layers.
+ head_use_batch_norm : bool, default=False
+ Whether to use batch normalization in the head layers.
+ layer_norm_after_embedding : bool, default=False
+ Whether to apply layer normalization after embedding.
+ pooling_method : str, default="cls"
+ Pooling method to be used ('cls', 'avg', etc.).
+ norm_first : bool, default=False
+ Whether to apply normalization before other operations in each transformer block.
+ bias : bool, default=True
+ Whether to use bias in the linear layers.
+ rnn_activation : callable, default=nn.SELU()
+ Activation function for the transformer layers.
+ bidirectional : bool, default=False.
+ Whether to process data bidirectionally
+ cat_encoding : str, default="int"
+ Encoding method for categorical features.
+ n_bins : int, default=50
+ The number of bins to use for numerical feature binning. This parameter is relevant
+ only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
+ numerical_preprocessing : str, default="ple"
+ The preprocessing strategy for numerical features. Valid options are
+ 'binning', 'one_hot', 'standardization', and 'normalization'.
+ use_decision_tree_bins : bool, default=False
+ If True, uses decision tree regression/classification to determine
+ optimal bin edges for numerical feature binning. This parameter is
+ relevant only if `numerical_preprocessing` is set to 'binning' or 'one_hot'.
+ binning_strategy : str, default="uniform"
+ Defines the strategy for binning numerical features. Options include 'uniform',
+ 'quantile', or other sklearn-compatible strategies.
+ cat_cutoff : float or int, default=0.03
+ Indicates the cutoff after which integer values are treated as categorical.
+ If float, it's treated as a percentage. If int, it's the maximum number of
+ unique values for a column to be considered categorical.
+ treat_all_integers_as_numerical : bool, default=False
+ If True, all integer columns will be treated as numerical, regardless
+ of their unique value count or proportion.
+ degree : int, default=3
+ The degree of the polynomial features to be used in preprocessing.
+ knots : int, default=12
+ The number of knots to be used in spline transformations.
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(model=TabulaRNN, config=DefaultTabulaRNNConfig, **kwargs)
diff --git a/mambular/preprocessing/preprocessor.py b/mambular/preprocessing/preprocessor.py
index f948877..c485f88 100644
--- a/mambular/preprocessing/preprocessor.py
+++ b/mambular/preprocessing/preprocessor.py
@@ -227,7 +227,9 @@ def fit(self, X, y=None):
numeric_transformer_steps.append(("scaler", StandardScaler()))
elif self.numerical_preprocessing == "normalization":
- numeric_transformer_steps.append(("normalizer", MinMaxScaler()))
+ numeric_transformer_steps.append(
+ ("normalizer", MinMaxScaler(feature_range=(-1, 1)))
+ )
elif self.numerical_preprocessing == "quantile":
numeric_transformer_steps.append(
@@ -240,12 +242,15 @@ def fit(self, X, y=None):
)
elif self.numerical_preprocessing == "polynomial":
+ numeric_transformer_steps.append(("scaler", StandardScaler()))
numeric_transformer_steps.append(
(
"polynomial",
PolynomialFeatures(self.degree, include_bias=False),
)
)
+ # if self.degree > 10:
+ # numeric_transformer_steps.append(("normalizer", MinMaxScaler()))
elif self.numerical_preprocessing == "splines":
numeric_transformer_steps.append(
@@ -260,13 +265,9 @@ def fit(self, X, y=None):
)
elif self.numerical_preprocessing == "ple":
- numeric_transformer_steps.append(("normalizer", MinMaxScaler()))
numeric_transformer_steps.append(
- ("ple", PLE(n_bins=self.n_bins, task=self.task))
+ ("normalizer", MinMaxScaler(feature_range=(-1, 1)))
)
-
- elif self.numerical_preprocessing == "ple":
- numeric_transformer_steps.append(("normalizer", MinMaxScaler()))
numeric_transformer_steps.append(
("ple", PLE(n_bins=self.n_bins, task=self.task))
)
diff --git a/paper/paper.bib b/paper/paper.bib
deleted file mode 100644
index a0a9e1e..0000000
--- a/paper/paper.bib
+++ /dev/null
@@ -1,73 +0,0 @@
-@article{Gu,
- title={Mamba: Linear-time sequence modeling with selective state spaces},
- author={Gu, Albert and Dao, Tri},
- journal={arXiv preprint arXiv:2312.00752},
- year={2023}
-}
-
-
-@article{Ahamed,
- title={MambaTab: A Simple Yet Effective Approach for Handling Tabular Data},
- author={Ahamed, Md Atik and Cheng, Qiang},
- journal={arXiv preprint arXiv:2401.08867},
- year={2024}
-}
-
-
-@article{Gorishnyi1,
- title={Revisiting deep learning models for tabular data},
- author={Gorishniy, Yury and Rubachev, Ivan and Khrulkov, Valentin and Babenko, Artem},
- journal={Advances in Neural Information Processing Systems},
- volume={34},
- pages={18932--18943},
- year={2021}
-}
-
-
-@article{Huang,
- title={Tabtransformer: Tabular data modeling using contextual embeddings},
- author={Huang, Xin and Khetan, Ashish and Cvitkovic, Milan and Karnin, Zohar},
- journal={arXiv preprint arXiv:2012.06678},
- year={2020}
-}
-
-
-@inproceedings{Thielmann,
- title={Neural additive models for location scale and shape: A framework for interpretable neural regression beyond the mean},
- author={Thielmann, Anton Frederik and Kruse, Ren{\'e}-Marcel and Kneib, Thomas and S{\"a}fken, Benjamin},
- booktitle={International Conference on Artificial Intelligence and Statistics},
- pages={1783--1791},
- year={2024},
- organization={PMLR}
-}
-
-
-@article{Kneib,
- title={Rage against the mean--a review of distributional regression approaches},
- author={Kneib, Thomas and Silbersdorff, Alexander and S{\"a}fken, Benjamin},
- journal={Econometrics and Statistics},
- volume={26},
- pages={99--123},
- year={2023},
- publisher={Elsevier}
-}
-
-
-@article{Pedregosa,
- title={Scikit-learn: Machine learning in Python},
- author={Pedregosa, Fabian and Varoquaux, Ga{\"e}l and Gramfort, Alexandre and Michel, Vincent and Thirion, Bertrand and Grisel, Olivier and Blondel, Mathieu and Prettenhofer, Peter and Weiss, Ron and Dubourg, Vincent and others},
- journal={the Journal of machine Learning research},
- volume={12},
- pages={2825--2830},
- year={2011},
- publisher={JMLR. org}
-}
-
-@article{natt,
- title={Interpretable Additive Tabular Transformer Networks},
- author={Anton Frederik Thielmann and Arik Reuter and Thomas Kneib and David R{\"u}gamer and Benjamin S{\"a}fken},
- journal={Transactions on Machine Learning Research},
- issn={2835-8856},
- year={2024},
- url={https://openreview.net/forum?id=TdJ7lpzAkD},
-}
\ No newline at end of file
diff --git a/paper/paper.md b/paper/paper.md
deleted file mode 100644
index 0bd32e1..0000000
--- a/paper/paper.md
+++ /dev/null
@@ -1,70 +0,0 @@
----
-title: "Mambular: A User-Centric Python Library for Tabular Deep Learning Leveraging Mamba Architecture"
-tags:
- - Python
- - Tabular Deep Learning
- - Mamba
- - Distributional Regression
-authors:
- - name: Anton Frederik Thielmann
- orcid: 0000-0002-6768-8992
- affiliation: 1
- - name: Christoph Weisser
- affiliation: 1
- - name: Manish Kumar
- affiliation: 1
- - name: Benjamin Saefken
- affiliation: 2
- - name: Soheila Samiee
- affiliation: 3
-affiliations:
- - name: BASF SE, Germany
- index: 1
- - name: TU Clausthal, Germany
- index: 2
- - name: BASF Canada Inc, Canada
- index: 3
-date: 22 April 2024
-bibliography: paper.bib
----
-
-# 1. Summary
-
-Mambular is a Python library designed to leverage the capabilities of the recently proposed Mamba architecture [@Gu] for deep learning tasks involving tabular datasets. The effectiveness of the attention mechanism, as demonstrated by models such as TabTransformer [@Ahamed] and FT-Transformer [@Gorishnyi1], is extended to these data types, showcasing the potential for sequence-focused architectures to excel in this domain. Thus, sequence-focused architectures can also achieve state-of-the-art performances for tabular data problems. [@Huang] already demonstrated that the Mamba architecture, similar to the attention mechanism, can effectively be used when dealing with tabular data. Mambular closely follows [@Gorishnyi1], but uses Mamba blocks instead of transformer blocks.
-Furthermore, it offers enhanced flexibility in model architecture with respect to embedding activation, pooling layers, and task-specific head architectures. Choosing the appropriate settings, a user can thus easily implement the models presented in [@Huang].
-
-# 2. Statement of Need
-Transformer-based models for tabular data have become powerful alternatives to traditional gradient-based decision trees. [@Huang; @Gorishnyi1; @natt]. However, effectively training these models requires users to: **i)** deeply understand the intricacies of tabular transformer networks, **ii)** master various data type-dependent preprocessing techniques, **iii)** navigate complex deep learning libraries.
-This either leads researchers and practitioners alike to develop extensive custom scripts and libraries to fit these models or discourages them from using these advanced tools altogether. However, since tabular transformer models are becoming more popular and powerful, they should be easy to use, also for practitioners. Mambular addresses this by offering a straightforward framework that allows users to easily train tabular models using the innovative Mamba architecture.
-
-# 3. Methodology
-The Mambular default architecture, independent of the task follows the straight forward architecture of tabular tansformer models [@Ahamed; @Gorishnyi1; @Huang]:
-If the numerical features are integer binned they are treated as categorical features and each feature/variable is passed through an embedding layer. When other numerical preprocessing techniques are applied (or no preprocessing), the numerical features are passed through a single feed-forward dense layer with the same output dimensionality as the embedding layers [@Gorishnyi1]. By default, no activation is used on the created embeddings, but the users can easily change that with available arguments. The created embeddings are passed through a stack of Mamba layers after which the contextualized embeddings are pooled (default is average pooling). Mambular also offers the use of cls token embeddings instead of pooling layers. After pooling, RMS layer normalization from [@Gu] is applied by default, followed by a task-specific model head.
-
-### 3.1 Models
-Mambular includes the following three model classes:
-**i)** *MambularRegressor* for regression tasks, **ii)** *MambularClassifier* for classification tasks and **iii)** *MambularLSS* for distributional regression tasks, similar to [@Thielmann].^[ See e.g. [@Kneib] for an overview on distributional regression.]
-
-
-The loss functions are respectively the **i)** Mean squared error loss, **ii)** categorical cross entropy (Binary for binary classification) and **iii)** the negative log-likelihood for distributional regression. For **iii)** all distributional parameters have default activation/link functions that adhere to the distributional restrictions (e.g. positive variance for a normal distribution) but can be adapted to the users preferences. The inclusion of a distributional model focusing on regression beyond the mean further allows users to account for aleatoric uncertainty [@Kneib] without increasing the number of parameters or the complexity of the model.
-
-# 4. Ecosystem Compatibility and Flexibility
-
-Mambular is seamlessly compatible with the scikit-learn [@Pedregosa] ecosystem, allowing users to incorporate Mambular models into their existing workflows with minimal friction. This compatibility extends to various stages of the machine learning process, including data preprocessing, model training, evaluation, and hyperparameter tuning.
-
-Furthermore, Mambular's design emphasizes flexibility and user-friendliness. The library offers a range of customizable options for model architecture, including the choice of preprocessing, activation functions, pooling layers, normalization layers, regularization and more. This level of customization ensures that practitioners can tailor their models to the specific requirements of their tabular data tasks, optimizing performance and achieving state-of-the-art results as demonstrated by [@Ahamed].
-
-
-
-### 4.1 Preprocessing Capabilities
-
-Mambular includes a comprehensive preprocessing module also following scikit-learns preprocessing pipeline.
-The preprocessing module supports a wide range of data transformation techniques, including ordinal and one-hot encoding for categorical variables, decision tree-based binning for numerical features, and various strategies for handling missing values. By leveraging these preprocessing tools, users can ensure that their data is in the best possible shape for training Mambular models, leading to improved model performance.
-
-# Acknowledgements
-We sincerely acknowledge and appreciate the financial support provided by the Key Digital Capability (KDC) for Generative AI at BASF and the BASF Data & AI Academy, which played a critical role in facilitating this research.
-
-# References
-
-
-