Improve Mamba Speed #129

Open
Jasmine-ycj opened this issue Sep 27, 2024 · 5 comments

Jasmine-ycj commented Sep 27, 2024

Hello AnFreTh,

Thank you for your work on this project. I am currently using Mambular to process tabular data, but I am experiencing very slow training speeds. On average, each epoch is taking around 80 minutes to complete.

Here are the details of my setup:

  • Batch size: 256
  • GPU: NVIDIA 4090
  • Data: 8139 samples, each with 425 features
  • Model settings: default parameters

For comparison, when I use ResNet or FT-Transformer as tabular encoder with the same setup, the training speed is approximately 25 seconds per epoch, which is significantly faster. Is it expected that Mambular would be much slower than ResNet or FT Transformer? Or could this be an issue with my configuration or code?

I would appreciate any insight you could provide. Is there any known issue, or something I can adjust in my configuration to improve the speed?

Please let me know if you need additional information to help diagnose the problem.

Thank you for your time and assistance!

Jasmine-ycj added the "question" label (Further information is requested) Sep 27, 2024
Collaborator

AnFreTh commented Sep 27, 2024

It is expected that Mambular is slower than, e.g., FT-Transformer, especially on datasets with many features, since training time increases linearly with sequence length (the number of features). However, in our experiments the slowdown was only a factor of 2.5-3, and Mambular was more memory-efficient than FT-Transformer.
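
To illustrate the linear dependence on sequence length, here is a minimal sketch of a sequential state-space recurrence in plain Python. It is an illustration only, not Mambular's actual code: the function name `sequential_scan` and the scalar coefficients are assumptions made for the example. Each token (feature) needs one step that depends on the previous state, so the loop runs once per feature.

```python
def sequential_scan(a, b, x):
    """Compute h_t = a[t] * h_{t-1} + b[t] * x[t], one dependent step per token."""
    h = 0.0
    states = []
    for a_t, b_t, x_t in zip(a, b, x):
        h = a_t * h + b_t * x_t  # needs the previous h, so steps run in order
        states.append(h)
    return states


# Hypothetical sequence: 423 feature tokens plus one CLS token.
seq_len = 424
states = sequential_scan([0.5] * seq_len, [1.0] * seq_len, [1.0] * seq_len)
print(len(states))  # one state per token
```

With several hundred features, a loop like this is executed hundreds of times per layer and per batch, which is one plausible reason a Python-level implementation falls behind fused GPU kernels.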

Could you provide a minimal code example with simulated data where you experience similar training times? Then we can verify.
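
For anyone preparing such an example, a small timing helper along the following lines may be useful. This is a sketch under assumptions: `seconds_per_epoch`, `train_step`, and `sync` are hypothetical names, and the workload below is a stand-in rather than a real model. On a GPU, passing `torch.cuda.synchronize` as `sync` makes sure queued work is included in the measurement.

```python
import time


def seconds_per_epoch(train_step, batches, sync=None):
    """Wall-clock seconds spent running train_step once on every batch."""
    start = time.perf_counter()
    for batch in batches:
        train_step(batch)
    if sync is not None:
        sync()  # e.g. torch.cuda.synchronize, so pending GPU work is counted
    return time.perf_counter() - start


# Stand-in workload: 32 batches of 256 row indices (8139 samples at batch size 256).
calls = []
fake_batches = [list(range(256)) for _ in range(32)]
elapsed = seconds_per_epoch(lambda batch: calls.append(len(batch)), fake_batches)
print(f"{len(calls)} batches in {elapsed:.4f} s")
```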

@Jasmine-ycj
Author

Hello AnFreTh,

Thank you for your reply. Based on your suggestion, I have prepared a minimal code example for you to review.

In my current framework, I am using Mambular as the tabular encoder within a table-image contrastive learning setting. I defined a CustomMambularEncoder class, making the following modifications:

  1. I used an embedding method consistent with FT-Transformer due to differences in how the data is read.
  2. I removed the classification head, so the encoder outputs only the feature vectors from Mambular.

For simplicity, the provided code example only uses simulated tabular data. This dataset has 8139 samples in total, with 6530 samples split between the training and validation sets. Each sample consists of 423 numerical features only, with no categorical features.

When running this simplified code (with a batch size of 16), training the Mambular encoder takes approximately 2.5 hours per epoch, while using the FT-Transformer encoder takes around 15 seconds per epoch, and using ResNet as the encoder takes about 7 seconds per epoch.
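
As a rough sanity check on these numbers, the per-batch cost and relative slowdown can be worked out as below. This is plain arithmetic on the figures quoted above; the variable names are illustrative, and the batch count assumes the DataLoader keeps the last partial batch (the PyTorch default).

```python
import math

n_samples, batch_size = 6530, 16
batches_per_epoch = math.ceil(n_samples / batch_size)  # last partial batch is kept

# Reported wall-clock time per epoch, in seconds.
epoch_seconds = {"Mambular": 2.5 * 3600, "FT-Transformer": 15, "ResNet": 7}

for name, secs in epoch_seconds.items():
    per_batch = secs / batches_per_epoch
    slowdown = epoch_seconds["Mambular"] / secs
    print(f"{name}: {per_batch:.3f} s/batch, Mambular is {slowdown:.0f}x slower")
```

That works out to 409 batches per epoch and a 600x gap to FT-Transformer, far beyond the factor of 2.5-3 mentioned earlier in the thread.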

I have attached the code example for your review. Please let me know if anything else is needed to further investigate the issue.

Thank you again for your help!

import torch
import pytorch_lightning as pl
from torch.nn import functional as F
from mambular.arch_utils.mamba_arch import Mamba
from mambular.arch_utils.normalization_layers import (
    RMSNorm,
    LayerNorm,
    LearnableLayerScaling,
    BatchNorm,
    InstanceNorm,
    GroupNorm,
)
from mambular.configs.mambular_config import DefaultMambularConfig
from mambular.base_models.basemodel import BaseModel
from typing import List
from torch import Tensor
# embedding methods from FT-Transformer
from models.rtdl_revisiting_models import LinearEmbeddings, _CLSEmbedding

class CustomMambularEncoder(BaseModel):
    """
    Modified encoder based on Mambular:
    - the embedding layer is modified to be consistent with FT-Transformer.
    - the tabular head is removed.
    """

    def __init__(
        self,
        n_cont_features: int,
        cat_cardinalities: List[int],
        n_categories: List[int],
        config: DefaultMambularConfig = DefaultMambularConfig(),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])

        self.lr = self.hparams.get("lr", config.lr)
        self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
        self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
        self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
        self.pooling_method = self.hparams.get("pooling_method", config.pooling_method)
        self.shuffle_embeddings = self.hparams.get(
            "shuffle_embeddings", config.shuffle_embeddings
        )

        self.mamba = Mamba(
            d_model=self.hparams.get("d_model", config.d_model),
            n_layers=self.hparams.get("n_layers", config.n_layers),
            expand_factor=self.hparams.get("expand_factor", config.expand_factor),
            bias=self.hparams.get("bias", config.bias),
            d_conv=self.hparams.get("d_conv", config.d_conv),
            conv_bias=self.hparams.get("conv_bias", config.conv_bias),
            dropout=self.hparams.get("dropout", config.dropout),
            dt_rank=self.hparams.get("dt_rank", config.dt_rank),
            d_state=self.hparams.get("d_state", config.d_state),
            dt_scale=self.hparams.get("dt_scale", config.dt_scale),
            dt_init=self.hparams.get("dt_init", config.dt_init),
            dt_max=self.hparams.get("dt_max", config.dt_max),
            dt_min=self.hparams.get("dt_min", config.dt_min),
            dt_init_floor=self.hparams.get("dt_init_floor", config.dt_init_floor),
            norm=globals()[self.hparams.get("norm", config.norm)],
            activation=self.hparams.get("activation", config.activation),
            # hparams keys below now match the config attribute names, so overrides are picked up
            bidirectional=self.hparams.get("bidirectional", config.bidirectional),
            use_learnable_interaction=self.hparams.get(
                "use_learnable_interaction", config.use_learnable_interaction
            ),
            AD_weight_decay=self.hparams.get("AD_weight_decay", config.AD_weight_decay),
            BC_layer_norm=self.hparams.get("BC_layer_norm", config.BC_layer_norm),
            layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps),
        )
        norm_layer = self.hparams.get("norm", config.norm)
        if norm_layer == "RMSNorm":
            self.norm_f = RMSNorm(
                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
            )
        elif norm_layer == "LayerNorm":
            self.norm_f = LayerNorm(
                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
            )
        elif norm_layer == "BatchNorm":
            self.norm_f = BatchNorm(
                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
            )
        elif norm_layer == "InstanceNorm":
            self.norm_f = InstanceNorm(
                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
            )
        elif norm_layer == "GroupNorm":
            self.norm_f = GroupNorm(
                1,
                self.hparams.get("d_model", config.d_model),
                eps=config.layer_norm_eps,
            )
        elif norm_layer == "LearnableLayerScaling":
            self.norm_f = LearnableLayerScaling(
                self.hparams.get("d_model", config.d_model)
            )
        else:
            raise ValueError(f"Unsupported normalization layer: {norm_layer}")

        # >>> Feature & cls embeddings in FT-Transformer.
        self.cont_embeddings = (
            LinearEmbeddings(n_cont_features+ len(cat_cardinalities), config.d_model) if n_cont_features > 0 else None
        )
        self.cls_embedding = _CLSEmbedding(config.d_model)
        # <<<

        if self.pooling_method == "cls":
            self.use_cls = True
        else:
            self.use_cls = self.hparams.get("use_cls", config.use_cls)

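        if self.shuffle_embeddings:
            # This class defines no `embedding_layer`, so derive the sequence
            # length directly: one CLS token plus one token per feature.
            seq_len = 1 + n_cont_features + len(cat_cardinalities)
            self.perm = torch.randperm(seq_len)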

    def forward(self, x):
        # cls embedding
        x_embeddings: List[Tensor] = []
        if self.cls_embedding is not None:
            x_embeddings.append(self.cls_embedding(x.shape[:-1]))

        # feature embedding, only numerical features in this case
        x_embeddings.append(self.cont_embeddings(x))
        
        x = torch.cat(x_embeddings, dim=1)
        
        if self.shuffle_embeddings:
            x = x[:, self.perm, :]

        x = self.mamba(x)

        if self.pooling_method == "avg":
            x = torch.mean(x, dim=1)
        elif self.pooling_method == "max":
            x, _ = torch.max(x, dim=1)
        elif self.pooling_method == "sum":
            x = torch.sum(x, dim=1)
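        elif self.pooling_method == "cls_token":
            # NOTE: the CLS token is concatenated first (index 0) above, so
            # x[:, -1] selects the last feature token unless embeddings are shuffled.
            x = x[:, -1]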
        elif self.pooling_method == "last":
            x = x[:, -1]
        else:
            raise ValueError(f"Invalid pooling method: {self.pooling_method}")

        x = self.norm_f(x)
        
        return x
    
class MinimalContrastiveMambularModel(pl.LightningModule):
    """
    Contrastive model for tabular data.
    """
    def __init__(self, feature_dim=128):
        super().__init__()
        self.mambular_encoder = CustomMambularEncoder(
            n_cont_features=423,
            cat_cardinalities=[],
            n_categories=[]
        )  
        self.projection_head = torch.nn.Sequential(
            torch.nn.Linear(64, feature_dim),  
            torch.nn.ReLU(),
            torch.nn.Linear(feature_dim, feature_dim)  
        )

    def forward(self, x):
        encoded = self.mambular_encoder(x)
        projected = self.projection_head(encoded)
        return F.normalize(projected, dim=1)  

    def training_step(self, batch, batch_idx):
        x1, x2 = batch  
        z1 = self.forward(x1)
        z2 = self.forward(x2)
        loss = self.contrastive_loss(z1, z2)
        self.log('train_loss', loss)
        return loss

    def contrastive_loss(self, z1, z2, temperature=0.5):
        # NT-Xent Loss 
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)
        batch_size = z1.shape[0]
        similarity_matrix = torch.matmul(z1, z2.T) / temperature
        labels = torch.arange(batch_size, device=z1.device)
        loss = F.cross_entropy(similarity_matrix, labels)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

# Simulated tabular data
# 6530 samples in the training & validation set
# 423 numerical features, 0 categorical features
simulated_data_1 = torch.rand(6530, 423)
simulated_data_2 = torch.rand(6530, 423)

# DataLoader
train_dataset = torch.utils.data.TensorDataset(simulated_data_1, simulated_data_2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16)

# training
model = MinimalContrastiveMambularModel()
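# `gpus=1` was removed in PyTorch Lightning 2.0; accelerator/devices is the replacement.
trainer = pl.Trainer(max_epochs=5, accelerator="gpu", devices=1, limit_train_batches=1.0)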
torch.cuda.empty_cache()
trainer.fit(model, train_loader)

Collaborator

AnFreTh commented Sep 28, 2024

I could not recreate the extreme differences you reported, but still using default Mambular was 10x slower than FTTransformer for this specific setup. We will update the current Mambablock implementation to increase speed.

AnFreTh changed the title from "Inquiry About Slow Training Speed" to "Improve Mamba Speed" Sep 28, 2024
@Jasmine-ycj
Author

> I could not recreate the extreme differences you reported, but still using default Mambular was 10x slower than FTTransformer for this specific setup. We will update the current Mambablock implementation to increase speed.

Thank you for taking the time to investigate the issue. I will try these and look forward to your updates. Thanks again for your help and support!

Collaborator

AnFreTh commented Sep 28, 2024

If you experiment further, you could try the original Mamba implementation instead of the Python Mamba implementation from Mambular: https://pypi.org/project/mamba-ssm/
If you do, please let us know whether it improves speed :).
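
A minimal sketch of what trying that package could look like, assuming `mamba-ssm` is installed and a CUDA GPU is available. The wrapper function `run_original_mamba_block` is a hypothetical helper, and the shapes (424 tokens for 423 features plus a CLS token, `d_model=64`) are assumptions chosen to resemble the example above, not settings taken from Mambular. The `Mamba(d_model, d_state, d_conv, expand)` call follows the usage shown in the mamba-ssm README.

```python
def run_original_mamba_block(batch=2, seq_len=424, d_model=64):
    """Run one mamba-ssm block on random data and return the output shape."""
    # Imported lazily so this file can be loaded on machines without the packages.
    import torch
    from mamba_ssm import Mamba  # pip install mamba-ssm

    block = Mamba(
        d_model=d_model,  # embedding size of each feature token
        d_state=16,       # SSM state dimension
        d_conv=4,         # local convolution width
        expand=2,         # block expansion factor
    ).to("cuda")
    x = torch.randn(batch, seq_len, d_model, device="cuda")
    y = block(x)  # output has the same shape as the input
    return tuple(y.shape)
```

Calling `run_original_mamba_block()` on a GPU machine should return the input shape; swapping such a block into the encoder in place of the Python implementation would still need the surrounding normalization and residual logic, which this sketch leaves out.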
