diff --git a/minerva/models/nets/__init__.py b/minerva/models/nets/__init__.py
index c6c35d5..bbfdb31 100644
--- a/minerva/models/nets/__init__.py
+++ b/minerva/models/nets/__init__.py
@@ -4,7 +4,6 @@
 from .image.unet import UNet
 from .image.wisenet import WiseNet
 from .mlp import MLP
-from .lfr import LearnFromRandomnessModel, RepeatedModuleList
 
 __all__ = [
     "SimpleSupervisedModel",
@@ -12,7 +11,5 @@
     "SETR_PUP",
     "UNet",
     "WiseNet",
-    "MLP",
-    "LearnFromRandomnessModel",
-    "RepeatedModuleList"
+    "MLP"
 ]
diff --git a/minerva/models/nets/mlp.py b/minerva/models/nets/mlp.py
index 8452749..b59e2c0 100644
--- a/minerva/models/nets/mlp.py
+++ b/minerva/models/nets/mlp.py
@@ -1,12 +1,14 @@
 from torch import nn
+from typing import Sequence
+
 
 class MLP(nn.Sequential):
     """
     A multilayer perceptron (MLP) implemented as a subclass of nn.Sequential.
-
-    The MLP consists of a series of linear layers interleaved with ReLU activation functions,
-    except for the last layer which is purely linear.
-
+
+    This MLP is composed of a sequence of linear layers interleaved with ReLU activation
+    functions, except for the final layer which remains purely linear.
+
     Example
     -------
@@ -21,32 +23,51 @@ class MLP(nn.Sequential):
     )
     """
 
-    def __init__(self, *layer_sizes: int):
+    def __init__(
+        self,
+        layer_sizes: Sequence[int],
+        activation_cls: type = nn.ReLU,
+        *args,
+        **kwargs
+    ):
         """
-        Initializes the MLP with the given layer sizes.
+        Initializes the MLP with specified layer sizes.
 
         Parameters
         ----------
-        *layer_sizes: int
-            A variable number of positive integers specifying the size of each layer.
-            There must be at least two integers, representing the input and output layers.
-
+        layer_sizes : Sequence[int]
+            A sequence of positive integers indicating the size of each layer.
+            At least two integers are required, representing the input and output layers.
+        activation_cls : type
+            The class of the activation function to use between layers. Default is nn.ReLU.
+        *args
+            Additional arguments passed to the activation function.
+        **kwargs
+            Additional keyword arguments passed to the activation function.
+
         Raises
        ------
-        AssertionError: If less than two layer sizes are provided.
-
-        AssertionError: If any layer size is not a positive integer.
+        AssertionError
+            If fewer than two layer sizes are provided or if any layer size is not a positive integer.
+        AssertionError
+            If activation_cls does not inherit from torch.nn.Module.
""" + assert ( len(layer_sizes) >= 2 ), "Multilayer perceptron must have at least 2 layers" assert all( ls > 0 and isinstance(ls, int) for ls in layer_sizes - ), "All layer sizes must be a positive integer" + ), "All layer sizes must be positive integers" + + assert issubclass( + activation_cls, nn.Module + ), "activation_cls must inherit from torch.nn.Module" layers = [] for i in range(len(layer_sizes) - 2): - layers += [nn.Linear(layer_sizes[i], layer_sizes[i + 1]), nn.ReLU()] - layers += [nn.Linear(layer_sizes[-2], layer_sizes[-1])] + layers.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1])) + layers.append(activation_cls(*args, **kwargs)) + layers.append(nn.Linear(layer_sizes[-2], layer_sizes[-1])) super().__init__(*layers) diff --git a/minerva/models/ssl/__init__.py b/minerva/models/ssl/__init__.py new file mode 100644 index 0000000..c14ad54 --- /dev/null +++ b/minerva/models/ssl/__init__.py @@ -0,0 +1,6 @@ +from .lfr import LearnFromRandomnessModel, RepeatedModuleList + +__all__ = [ + "LearnFromRandomnessModel", + "RepeatedModuleList" +] \ No newline at end of file diff --git a/minerva/models/nets/lfr.py b/minerva/models/ssl/lfr.py similarity index 100% rename from minerva/models/nets/lfr.py rename to minerva/models/ssl/lfr.py diff --git a/tests/models/nets/test_lfr.py b/tests/models/ssl/test_lfr.py similarity index 95% rename from tests/models/nets/test_lfr.py rename to tests/models/ssl/test_lfr.py index 666e050..c93cfec 100644 --- a/tests/models/nets/test_lfr.py +++ b/tests/models/ssl/test_lfr.py @@ -2,7 +2,7 @@ from torch.nn import Sequential, Conv2d, CrossEntropyLoss from torchvision.transforms import Resize -from minerva.models.nets.lfr import RepeatedModuleList, LearnFromRandomnessModel +from minerva.models.ssl.lfr import RepeatedModuleList, LearnFromRandomnessModel from minerva.models.nets.image.deeplabv3 import DeepLabV3Backbone @@ -57,4 +57,3 @@ def __init__(self): # Test the configure_optimizers method optimizer = model.configure_optimizers() assert optimizer is not None -