Improve Mamba Speed #129
It is expected that Mambular is slower than, e.g., FT-Transformer, especially for datasets with many features, since training time increases linearly with sequence length (the number of features). However, in our experiments this was only a factor of 2.5–3, while Mambular was more memory-efficient than FT-Transformer. Could you provide a minimal code example with simulated data where you experience similar training times? Then we can verify.
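As a rough illustration (this is not Mambular's actual code), the gap between linear scaling and the observed slowdown often comes from how the recurrence is executed: a Python-level scan over the feature axis must step through positions one at a time, while an attention-style layer is a single vectorised matmul. A minimal timing sketch:

```python
import time

import numpy as np

rng = np.random.default_rng(0)

def sequential_scan(x, a=0.9):
    """Python-level recurrence over the feature axis, mimicking a
    pure-Python SSM scan: each position depends on the previous hidden
    state, so the loop runs once per feature and cannot be vectorised."""
    h = np.zeros(x.shape[1])
    for t in range(x.shape[0]):
        h = a * h + x[t]
    return h

def one_shot_matmul(x, w):
    """Attention-style fused operation: one vectorised matmul call."""
    return x @ w

seq_len, d = 423, 64  # 423 features, matching the report in this thread
x = rng.standard_normal((seq_len, d))
w = rng.standard_normal((d, d))

t0 = time.perf_counter(); sequential_scan(x); t_scan = time.perf_counter() - t0
t0 = time.perf_counter(); one_shot_matmul(x, w); t_mm = time.perf_counter() - t0
print(f"python scan: {t_scan:.6f}s, vectorised matmul: {t_mm:.6f}s")
```

The Python loop pays interpreter overhead once per feature, so with 423 features per sample the per-step cost dominates even though the asymptotic complexity is linear.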
Hello AnFreTh,
Thank you for your reply. Based on your suggestion, I have prepared a minimal code example for you to review. In my current framework, I am using Mambular as the tabular encoder within a table-image contrastive learning setting. I defined a
For simplicity, the provided code example only uses simulated tabular data. This dataset has 8139 samples in total, with 6530 samples split between the training and validation sets. Each sample consists of 423 numerical features only, with no categorical features. When running this simplified code (with a batch size of 16), training the Mambular encoder takes approximately 2.5 hours per epoch, while using the FT-Transformer encoder takes around 15 seconds per epoch, and using ResNet as the encoder takes about 7 seconds per epoch. I have attached the code example for your review. Please let me know if anything else is needed to further investigate the issue. Thank you again for your help!
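A hypothetical reconstruction of the simulated setup described above (the shapes and batch size are taken from the report; the data itself is random, and the variable names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(42)

# 8139 samples, 423 numerical features, no categorical features.
n_total, n_features = 8139, 423
X = rng.standard_normal((n_total, n_features)).astype(np.float32)

# 6530 samples are split between training and validation; the rest are held out.
n_trainval = 6530
X_trainval, X_test = X[:n_trainval], X[n_trainval:]

# With a batch size of 16, each epoch over the train+validation pool
# runs ceil(6530 / 16) = 409 optimisation steps.
batch_size = 16
steps_per_epoch = int(np.ceil(len(X_trainval) / batch_size))
print(steps_per_epoch)  # → 409
```

At ~2.5 hours per epoch over 409 steps, each Mambular step would be taking over 20 seconds, versus well under 0.1 s per step for the FT-Transformer run reported above.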
I could not reproduce the extreme differences you reported, but the default Mambular was still about 10x slower than FT-Transformer for this specific setup. We will update the current Mamba block implementation to improve speed.
Thank you for taking the time to investigate the issue. I will try these and look forward to your updates. Thanks again for your help and support!
If you experiment further, you could try the original Mamba implementation (https://pypi.org/project/mamba-ssm/) instead of the pure-Python Mamba implementation from Mambular.
Hello AnFreTh,
Thank you for your work on this project. I am currently using Mambular to process tabular data, but I am experiencing very slow training speeds. On average, each epoch is taking around 80 minutes to complete.
Here are the details of my setup:
For comparison, when I use ResNet or FT-Transformer as the tabular encoder with the same setup, training takes approximately 25 seconds per epoch, which is significantly faster. Is it expected that Mambular would be this much slower than ResNet or FT-Transformer, or could this be an issue with my configuration or code?
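A quick back-of-envelope on the timings above: 80 minutes per epoch versus roughly 25 seconds per epoch is nearly a 200x slowdown, far beyond the linear-in-features overhead one would expect from the architecture alone.

```python
# Slowdown factor implied by the reported per-epoch timings.
mambular_epoch_s = 80 * 60  # ~80 minutes per epoch with Mambular
baseline_epoch_s = 25       # ~25 s per epoch with ResNet / FT-Transformer

print(mambular_epoch_s / baseline_epoch_s)  # → 192.0
```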
I would appreciate any insight you could provide. Is there any known issue, or something I can adjust in my configuration to improve the speed?
Please let me know if you need additional information to help diagnose the problem.
Thank you for your time and assistance!