-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
95 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,93 @@ | ||
from .sklearn_base_regressor import SklearnBaseRegressor | ||
from .sklearn_base_classifier import SklearnBaseClassifier | ||
from .sklearn_base_lss import SklearnBaseLSS | ||
from ..base_models.mlp import MLP | ||
from ..configs.mlp_config import DefaultMLPConfig | ||
from .sklearn_base_classifier import SklearnBaseClassifier | ||
from .sklearn_base_lss import SklearnBaseLSS | ||
from .sklearn_base_regressor import SklearnBaseRegressor | ||
|
||
|
||
class MLPRegressor(SklearnBaseRegressor): | ||
"""Multi-Layer Perceptron regressor. This class extends the SklearnBaseRegressor class and uses the MLP model | ||
with the default MLP configuration. | ||
The accepted arguments to the MLPRegressor class are the same as the attributes in the DefaultMLPConfig dataclass. | ||
Parameters: | ||
**kwargs: Additional keyword arguments to be passed to the parent class. The kwargs should have the below attributes / congigurations. | ||
model (class): The model class to use. Default is MLP. | ||
lr (float): The learning rate for the optimizer. Default is 1e-04. | ||
lr_patience (int): The number of epochs to wait before reducing the learning rate. Default is 10. | ||
weight_decay (float, default: 1e-6): The weight decay (L2 penalty) for the optimizer. Default is 1e-06. | ||
lr_factor (float): The factor by which the learning rate is reduced. Default is 0.1. | ||
layer_sizes (list): The sizes of the hidden layers in the MLP. Default is [128, 128, 32]. | ||
activation (callable): The activation function to use in the MLP. Default is nn.SELU(). | ||
skip_layers (bool): Whether to skip some layers in the MLP. Default is False. | ||
dropout (float): The dropout probability for the MLP. Default is 0.5. | ||
norm (str): The normalization method to use in the MLP. Default is None. | ||
use_glu (bool): Whether to use Gated Linear Units (GLU) in the MLP. Default is False. | ||
skip_connections (bool): Whether to use skip connections in the MLP. Default is False. | ||
batch_norm (bool): Whether to use batch normalization in the MLP. Default is False. | ||
layer_norm (bool): Whether to use layer normalization in the MLP. Default is False. | ||
Notes: | ||
- The accepted arguments to the MLPRegressor class are the same as the attributes in the DefaultMLPConfig dataclass. | ||
- MLPRegressor used SKlearnBaseRegressor as the parent class. The methods for fitting, predicting, and evaluating the model are inherited from the parent class. Please refer to the parent class for more information. | ||
See Also: | ||
:class:`mambular.models.SklearnBaseRegressor` | ||
Examples: | ||
>>> from mambular.models import MLPRegressor | ||
>>> model = MLPRegressor(model=MLP, lr=0.01, layer_sizes=[128, 128, 64]) | ||
>>> print(model) | ||
MLPRegressor(model=MLP, config=DefaultMLPConfig(lr=0.01, lr_patience=10, weight_decay=1e-06, lr_factor=0.1, layer_sizes=[128, 128, 64], | ||
activation=SELU(), skip_layers=False, dropout=0.5, norm=None, use_glu=False, skip_connections=False, batch_norm=False, layer_norm=False)) | ||
>>> model.fit(X_train, y_train) | ||
>>> preds = model.predict(X_test) | ||
>>> model.evaluate(X_test, y_test) | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(model=MLP, config=DefaultMLPConfig, **kwargs) | ||
|
||
|
||
class MLPClassifier(SklearnBaseClassifier): | ||
"""Multi-Layer Perceptron classifier. | ||
This class extends the SklearnBaseClassifier class and uses the MLP model | ||
with the default MLP configuration. | ||
Parameters: | ||
**kwargs: Additional keyword arguments to be passed to the parent class. | ||
Attributes: | ||
model: The MLP model. | ||
config: The default MLP configuration. | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(model=MLP, config=DefaultMLPConfig, **kwargs) | ||
|
||
|
||
class MLPLSS(SklearnBaseLSS): | ||
"""Multi-Layer Perceptron least squares solver. | ||
This class extends the SklearnBaseLSS class and uses the MLP model | ||
with the default MLP configuration. | ||
Parameters: | ||
**kwargs: Additional keyword arguments to be passed to the parent class. | ||
Attributes: | ||
model: The MLP model. | ||
config: The default MLP configuration. | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(model=MLP, config=DefaultMLPConfig, **kwargs) |