Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a base model for LFR #70

Merged
merged 6 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion minerva/models/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@

# SSL Models

...
| **Model** | **Authors** | **Task** | **Type** | **Input Shape** | **Python Class** | **Observations** |
|-----------------------------------------|---------------|----------|----------|:---------------:|:--------------------------------------------:|-------------------|
otavioon marked this conversation as resolved.
Show resolved Hide resolved
| [LFR](https://arxiv.org/abs/2310.07756) | Yi Sui et al. | Any | Any | Any | minerva.models.nets.LearnFromRandomnessModel | |
5 changes: 5 additions & 0 deletions minerva/models/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
from .image.setr import SETR_PUP
from .image.unet import UNet
from .image.wisenet import WiseNet
from .mlp import MLP
from .lfr import LearnFromRandomnessModel, RepeatedModuleList

__all__ = [
"SimpleSupervisedModel",
"DeepLabV3",
"SETR_PUP",
"UNet",
"WiseNet",
"MLP",
"LearnFromRandomnessModel",
"RepeatedModuleList"
]
208 changes: 208 additions & 0 deletions minerva/models/nets/lfr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
import lightning as L
import torch


class RepeatedModuleList(torch.nn.ModuleList):
"""
A module list with the same module `cls`, instantiated `size` times.
"""

def __init__(
self,
size: int,
cls: type,
*args,
**kwargs
):
"""
Initializes the RepeatedModuleList with multiple instances of a given module class.

Parameters
----------
size: int
The number of instances to create.
cls: type
The module class to instantiate. Must be a subclass of `torch.nn.Module`.
*args:
Positional arguments to pass to the module class constructor.
**kwargs:
Keyword arguments to pass to the module class constructor.

Raises
------
AssertionError:
If `cls` is not a subclass of `torch.nn.Module`.

Example
-------
>>> class SimpleModule(torch.nn.Module):
>>> def __init__(self, in_features, out_features):
>>> super().__init__()
>>> self.linear = torch.nn.Linear(in_features, out_features)
>>>
>>> repeated_modules = RepeatedModuleList(3, SimpleModule, 10, 5)
>>> print(repeated_modules)
RepeatedModuleList(
(0): SimpleModule(
(linear): Linear(in_features=10, out_features=5, bias=True)
)
(1): SimpleModule(
(linear): Linear(in_features=10, out_features=5, bias=True)
)
(2): SimpleModule(
(linear): Linear(in_features=10, out_features=5, bias=True)
)
)
"""

assert issubclass(
cls, torch.nn.Module
), f"{cls} does not derive from torch.nn.Module"

super().__init__([cls(*args, **kwargs) for _ in range(size)])


class LearnFromRandomnessModel(L.LightningModule):
"""
A PyTorch Lightning model for pretraining with the technique
'Learning From Random Projectors'.

References
----------
Yi Sui, Tongzi Wu, Jesse C. Cresswell, Ga Wu, George Stein, Xiao Shi Huang, Xiaochen Zhang, Maksims Volkovs.
"Self-supervised Representation Learning From Random Data Projectors", 2024
"""

def __init__(
self,
backbone: torch.nn.Module,
projectors: torch.nn.ModuleList,
predictors: torch.nn.ModuleList,
loss_fn: torch.nn.Module,
learning_rate: float = 1e-3,
flatten: bool = True,
):
"""
Initialize the LFR_Model.

Parameters
----------
backbone: torch.nn.Module
The backbone neural network for feature extraction.
projectors: torch.nn.ModuleList
A list of projector networks.
predictors: torch.nn.ModuleList
A list of predictor networks.
loss_fn: torch.nn.Module
The loss function to optimize.
learning_rate: float
The learning rate for the optimizer, by default 1e-3.
flatten: bool
Whether to flatten the input tensor or not, by default True.
"""
super().__init__()
self.backbone = backbone
self.projectors = projectors
self.predictors = predictors
self.loss_fn = loss_fn
self.learning_rate = learning_rate
self.flatten = flatten

for param in self.projectors.parameters():
param.requires_grad = False

for proj in self.projectors:
proj.eval()

def _loss_func(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
Calculate the loss between the output and the input data.

Parameters
----------
y_hat : torch.Tensor
The output data from the forward pass.
y : torch.Tensor
The input data/label.

Returns
-------
torch.Tensor
The loss value.
"""
loss = self.loss_fn(y_hat, y)
return loss

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the network.

Parameters
----------
x : torch.Tensor
The input data.

Returns
-------
torch.Tensor
The predicted output and projected input.
"""
z: torch.Tensor = self.backbone(x)

if self.flatten:
z = z.view(z.size(0), -1)
x = x.view(x.size(0), -1)

y_pred = torch.stack([predictor(z) for predictor in self.predictors], 1)
y_proj = torch.stack([projector(x) for projector in self.projectors], 1)

return y_pred, y_proj

def _single_step(
self, batch: torch.Tensor, batch_idx: int, step_name: str
) -> torch.Tensor:
"""
Perform a single training/validation/test step.

Parameters
----------
batch : torch.Tensor
The input batch of data.
batch_idx : int
The index of the batch.
step_name : str
The name of the step (train, val, test).

Returns
-------
torch.Tensor
The loss value for the batch.
"""
x = batch
y_pred, y_proj = self.forward(x)
loss = self._loss_func(y_pred, y_proj)
self.log(
f"{step_name}_loss",
loss,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)

return loss

def training_step(self, batch: torch.Tensor, batch_idx: int):
return self._single_step(batch, batch_idx, step_name="train")

def validation_step(self, batch: torch.Tensor, batch_idx: int):
return self._single_step(batch, batch_idx, step_name="val")

def test_step(self, batch: torch.Tensor, batch_idx: int):
return self._single_step(batch, batch_idx, step_name="test")

def configure_optimizers(self):
return torch.optim.Adam(
self.parameters(),
lr=self.learning_rate,
)
52 changes: 52 additions & 0 deletions minerva/models/nets/mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from torch import nn

class MLP(nn.Sequential):
"""
A multilayer perceptron (MLP) implemented as a subclass of nn.Sequential.

The MLP consists of a series of linear layers interleaved with ReLU activation functions,
except for the last layer which is purely linear.

Example
-------

>>> mlp = MLP(10, 20, 30, 40)
>>> print(mlp)
MLP(
(0): Linear(in_features=10, out_features=20, bias=True)
(1): ReLU()
(2): Linear(in_features=20, out_features=30, bias=True)
(3): ReLU()
(4): Linear(in_features=30, out_features=40, bias=True)
)
"""

def __init__(self, *layer_sizes: int):
otavioon marked this conversation as resolved.
Show resolved Hide resolved
"""
Initializes the MLP with the given layer sizes.

Parameters
----------
*layer_sizes: int
A variable number of positive integers specifying the size of each layer.
There must be at least two integers, representing the input and output layers.

Raises
------
AssertionError: If less than two layer sizes are provided.

AssertionError: If any layer size is not a positive integer.
"""
assert (
len(layer_sizes) >= 2
), "Multilayer perceptron must have at least 2 layers"
assert all(
ls > 0 and isinstance(ls, int) for ls in layer_sizes
), "All layer sizes must be a positive integer"

layers = []
for i in range(len(layer_sizes) - 2):
layers += [nn.Linear(layer_sizes[i], layer_sizes[i + 1]), nn.ReLU()]
layers += [nn.Linear(layer_sizes[-2], layer_sizes[-1])]

super().__init__(*layers)
60 changes: 60 additions & 0 deletions tests/models/nets/test_lfr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torch
from torch.nn import Sequential, Conv2d, CrossEntropyLoss
from torchvision.transforms import Resize

from minerva.models.nets.lfr import RepeatedModuleList, LearnFromRandomnessModel
from minerva.models.nets.image.deeplabv3 import DeepLabV3Backbone


def test_lfr():

## Example class for projector
class Projector(Sequential):
def __init__(self):
super().__init__(
Conv2d(3, 16, 5, 2),
Conv2d(16, 64, 5, 2),
Conv2d(64, 16, 5, 2),
Resize((100, 50)),
)

## Example class for predictor
class Predictor(Sequential):
def __init__(self):
super().__init__(Conv2d(2048, 16, 1), Resize((100, 50)))

# Declare model
model = LearnFromRandomnessModel(
DeepLabV3Backbone(),
RepeatedModuleList(5, Projector),
RepeatedModuleList(5, Predictor),
CrossEntropyLoss(),
flatten=False
)

# Test the class instantiation
assert model is not None

# # Test the forward method
input_shape = (2, 3, 701, 255)
expected_output_size = torch.Size([2, 5, 16, 100, 50])
x = torch.rand(*input_shape)

y_pred, y_proj = model(x)
assert (
y_pred.shape == expected_output_size
), f"Expected output shape {expected_output_size}, but got {y_pred.shape}"

assert (
y_proj.shape == expected_output_size
), f"Expected output shape {expected_output_size}, but got {y_proj.shape}"

# Test the _loss_func method
loss = model._loss_func(y_pred, y_proj)
assert loss is not None
# TODO: assert the loss result

# Test the configure_optimizers method
optimizer = model.configure_optimizers()
assert optimizer is not None

Loading