Commit
moved lfr to ssl folder, added parameters to mlp
fernandoGubiMarques authored and fernandoGubiMarques committed Jul 12, 2024
1 parent fd22265 commit 0d0e5f8
Showing 5 changed files with 45 additions and 22 deletions.
5 changes: 1 addition & 4 deletions minerva/models/nets/__init__.py
@@ -4,15 +4,12 @@
 from .image.unet import UNet
 from .image.wisenet import WiseNet
 from .mlp import MLP
-from .lfr import LearnFromRandomnessModel, RepeatedModuleList
 
 __all__ = [
     "SimpleSupervisedModel",
     "DeepLabV3",
     "SETR_PUP",
     "UNet",
     "WiseNet",
-    "MLP",
-    "LearnFromRandomnessModel",
-    "RepeatedModuleList"
+    "MLP"
 ]
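
For downstream code, the net effect of this export change is that MLP is still importable from the nets package, while the LFR classes now live in the new ssl subpackage introduced later in this commit. A short sketch, assuming the minerva package is installed:

from minerva.models.nets import MLP  # still exported here
from minerva.models.ssl import LearnFromRandomnessModel, RepeatedModuleList  # new location, see below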
53 changes: 37 additions & 16 deletions minerva/models/nets/mlp.py
@@ -1,12 +1,14 @@
 from torch import nn
+from typing import Sequence
 
 
 class MLP(nn.Sequential):
     """
     A multilayer perceptron (MLP) implemented as a subclass of nn.Sequential.
-    The MLP consists of a series of linear layers interleaved with ReLU activation functions,
-    except for the last layer which is purely linear.
+    This MLP is composed of a sequence of linear layers interleaved with ReLU activation
+    functions, except for the final layer which remains purely linear.
     Example
     -------
@@ -21,32 +23,51 @@ class MLP(nn.Sequential):
     )
     """
 
-    def __init__(self, *layer_sizes: int):
+    def __init__(
+        self,
+        layer_sizes: Sequence[int],
+        activation_cls: type = nn.ReLU,
+        *args,
+        **kwargs
+    ):
         """
-        Initializes the MLP with the given layer sizes.
+        Initializes the MLP with specified layer sizes.
         Parameters
         ----------
-        *layer_sizes: int
-            A variable number of positive integers specifying the size of each layer.
-            There must be at least two integers, representing the input and output layers.
+        layer_sizes : Sequence[int]
+            A sequence of positive integers indicating the size of each layer.
+            At least two integers are required, representing the input and output layers.
+        activation_cls : type
+            The class of the activation function to use between layers. Default is nn.ReLU.
+        *args
+            Additional arguments passed to the activation function.
+        **kwargs
+            Additional keyword arguments passed to the activation function.
         Raises
         ------
-        AssertionError: If less than two layer sizes are provided.
-        AssertionError: If any layer size is not a positive integer.
+        AssertionError
+            If fewer than two layer sizes are provided or if any layer size is not a positive integer.
+        AssertionError
+            If activation_cls does not inherit from torch.nn.Module.
         """
 
         assert (
             len(layer_sizes) >= 2
         ), "Multilayer perceptron must have at least 2 layers"
         assert all(
             ls > 0 and isinstance(ls, int) for ls in layer_sizes
-        ), "All layer sizes must be a positive integer"
+        ), "All layer sizes must be positive integers"
 
+        assert issubclass(
+            activation_cls, nn.Module
+        ), "activation_cls must inherit from torch.nn.Module"
 
         layers = []
         for i in range(len(layer_sizes) - 2):
-            layers += [nn.Linear(layer_sizes[i], layer_sizes[i + 1]), nn.ReLU()]
-        layers += [nn.Linear(layer_sizes[-2], layer_sizes[-1])]
+            layers.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1]))
+            layers.append(activation_cls(*args, **kwargs))
+        layers.append(nn.Linear(layer_sizes[-2], layer_sizes[-1]))
 
         super().__init__(*layers)
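
A brief usage sketch of the updated constructor; the layer sizes and the LeakyReLU choice below are illustrative, not taken from the commit:

from torch import nn
from minerva.models.nets.mlp import MLP

# Default: nn.ReLU between layers, final layer purely linear.
model = MLP([16, 32, 8])

# Custom activation class; extra keyword arguments are forwarded to it,
# yielding nn.LeakyReLU(negative_slope=0.1) between layers.
model = MLP([16, 32, 8], activation_cls=nn.LeakyReLU, negative_slope=0.1)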
6 changes: 6 additions & 0 deletions minerva/models/ssl/__init__.py
@@ -0,0 +1,6 @@
+from .lfr import LearnFromRandomnessModel, RepeatedModuleList
+
+__all__ = [
+    "LearnFromRandomnessModel",
+    "RepeatedModuleList"
+]
File renamed without changes.
@@ -2,7 +2,7 @@
 from torch.nn import Sequential, Conv2d, CrossEntropyLoss
 from torchvision.transforms import Resize
 
-from minerva.models.nets.lfr import RepeatedModuleList, LearnFromRandomnessModel
+from minerva.models.ssl.lfr import RepeatedModuleList, LearnFromRandomnessModel
 from minerva.models.nets.image.deeplabv3 import DeepLabV3Backbone
 
 
@@ -57,4 +57,3 @@ def __init__(self):
     # Test the configure_optimizers method
     optimizer = model.configure_optimizers()
     assert optimizer is not None
-