Skip to content

Commit

Permalink
docstring update for MLPRegressor
Browse files Browse the repository at this point in the history
  • Loading branch information
mkumar73 committed Jun 28, 2024
1 parent 2463b01 commit fff7f44
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 5 deletions.
2 changes: 0 additions & 2 deletions docs/api/models/Models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ mambular.models
.. autoclass:: mambular.models.MLPRegressor
:members:
:undoc-members:
:inherited-members:
:show-inheritance:

.. autoclass:: mambular.models.MLPLSS
:members:
Expand Down
19 changes: 19 additions & 0 deletions mambular/configs/mlp_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,28 @@
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class DefaultMLPConfig:
"""
Default configuration for a Multi-Layer Perceptron (MLP) model.
Attributes:
lr (float): The learning rate for the optimizer. Default is 1e-04.
lr_patience (int): The number of epochs to wait before reducing the learning rate. Default is 10.
weight_decay (float): The weight decay (L2 penalty) for the optimizer. Default is 1e-06.
lr_factor (float): The factor by which the learning rate is reduced. Default is 0.1.
layer_sizes (list): The sizes of the hidden layers in the MLP. Default is [128, 128, 32].
activation (callable): The activation function to use in the MLP. Default is nn.SELU().
skip_layers (bool): Whether to skip some layers in the MLP. Default is False.
dropout (float): The dropout probability for the MLP. Default is 0.5.
norm (str): The normalization method to use in the MLP. Default is None.
use_glu (bool): Whether to use Gated Linear Units (GLU) in the MLP. Default is False.
skip_connections (bool): Whether to use skip connections in the MLP. Default is False.
batch_norm (bool): Whether to use batch normalization in the MLP. Default is False.
layer_norm (bool): Whether to use layer normalization in the MLP. Default is False.
"""
lr: float = 1e-04
lr_patience: int = 10
weight_decay: float = 1e-06
Expand Down
79 changes: 76 additions & 3 deletions mambular/models/mlp.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,93 @@
from .sklearn_base_regressor import SklearnBaseRegressor
from .sklearn_base_classifier import SklearnBaseClassifier
from .sklearn_base_lss import SklearnBaseLSS
from ..base_models.mlp import MLP
from ..configs.mlp_config import DefaultMLPConfig
from .sklearn_base_classifier import SklearnBaseClassifier
from .sklearn_base_lss import SklearnBaseLSS
from .sklearn_base_regressor import SklearnBaseRegressor


class MLPRegressor(SklearnBaseRegressor):
"""Multi-Layer Perceptron regressor. This class extends the SklearnBaseRegressor class and uses the MLP model
with the default MLP configuration.
The accepted arguments to the MLPRegressor class are the same as the attributes in the DefaultMLPConfig dataclass.
Parameters:
**kwargs: Additional keyword arguments to be passed to the parent class. The kwargs should have the below attributes / congigurations.
model (class): The model class to use. Default is MLP.
lr (float): The learning rate for the optimizer. Default is 1e-04.
lr_patience (int): The number of epochs to wait before reducing the learning rate. Default is 10.
weight_decay (float, default: 1e-6): The weight decay (L2 penalty) for the optimizer. Default is 1e-06.
lr_factor (float): The factor by which the learning rate is reduced. Default is 0.1.
layer_sizes (list): The sizes of the hidden layers in the MLP. Default is [128, 128, 32].
activation (callable): The activation function to use in the MLP. Default is nn.SELU().
skip_layers (bool): Whether to skip some layers in the MLP. Default is False.
dropout (float): The dropout probability for the MLP. Default is 0.5.
norm (str): The normalization method to use in the MLP. Default is None.
use_glu (bool): Whether to use Gated Linear Units (GLU) in the MLP. Default is False.
skip_connections (bool): Whether to use skip connections in the MLP. Default is False.
batch_norm (bool): Whether to use batch normalization in the MLP. Default is False.
layer_norm (bool): Whether to use layer normalization in the MLP. Default is False.
Notes:
- The accepted arguments to the MLPRegressor class are the same as the attributes in the DefaultMLPConfig dataclass.
- MLPRegressor used SKlearnBaseRegressor as the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information.
See Also:
:class:`mambular.models.SklearnBaseRegressor`
Examples:
>>> from mambular.models import MLPRegressor
>>> model = MLPRegressor(model=MLP, lr=0.01, layer_sizes=[128, 128, 64])
>>> print(model)
MLPRegressor(model=MLP, config=DefaultMLPConfig(lr=0.01, lr_patience=10, weight_decay=1e-06, lr_factor=0.1, layer_sizes=[128, 128, 64],
activation=SELU(), skip_layers=False, dropout=0.5, norm=None, use_glu=False, skip_connections=False, batch_norm=False, layer_norm=False))
>>> model.fit(X_train, y_train)
>>> preds = model.predict(X_test)
>>> model.evaluate(X_test, y_test)
"""

def __init__(self, **kwargs):
super().__init__(model=MLP, config=DefaultMLPConfig, **kwargs)


class MLPClassifier(SklearnBaseClassifier):
"""Multi-Layer Perceptron classifier.
This class extends the SklearnBaseClassifier class and uses the MLP model
with the default MLP configuration.
Parameters:
**kwargs: Additional keyword arguments to be passed to the parent class.
Attributes:
model: The MLP model.
config: The default MLP configuration.
"""

def __init__(self, **kwargs):
super().__init__(model=MLP, config=DefaultMLPConfig, **kwargs)


class MLPLSS(SklearnBaseLSS):
"""Multi-Layer Perceptron least squares solver.
This class extends the SklearnBaseLSS class and uses the MLP model
with the default MLP configuration.
Parameters:
**kwargs: Additional keyword arguments to be passed to the parent class.
Attributes:
model: The MLP model.
config: The default MLP configuration.
"""

def __init__(self, **kwargs):
super().__init__(model=MLP, config=DefaultMLPConfig, **kwargs)

0 comments on commit fff7f44

Please sign in to comment.