Implement a base model for LFR #70
Merged: otavioon merged 6 commits into discovery-unicamp:main from fernandoGubiMarques:lfr-base on Jul 16, 2024.
The diff below shows changes from 3 of the 6 commits.
minerva/models/nets/lfr.py
@@ -0,0 +1,202 @@
from typing import Tuple

import lightning as L
import torch


class RepeatedModuleList(torch.nn.ModuleList):
    """
    A module list with the same module `cls`, instantiated `size` times.
    """

    def __init__(self, size, cls, *args, **kwargs):
        """
        Initializes the RepeatedModuleList with multiple instances of a given
        module class.

        Parameters
        ----------
        size: int
            The number of instances to create.
        cls: type
            The module class to instantiate. Must be a subclass of
            `torch.nn.Module`.
        *args:
            Positional arguments to pass to the module class constructor.
        **kwargs:
            Keyword arguments to pass to the module class constructor.

        Raises
        ------
        AssertionError:
            If `cls` is not a subclass of `torch.nn.Module`.

        Example
        -------
        >>> class SimpleModule(torch.nn.Module):
        ...     def __init__(self, in_features, out_features):
        ...         super().__init__()
        ...         self.linear = torch.nn.Linear(in_features, out_features)
        ...
        >>> repeated_modules = RepeatedModuleList(3, SimpleModule, 10, 5)
        >>> print(repeated_modules)
        RepeatedModuleList(
          (0): SimpleModule(
            (linear): Linear(in_features=10, out_features=5, bias=True)
          )
          (1): SimpleModule(
            (linear): Linear(in_features=10, out_features=5, bias=True)
          )
          (2): SimpleModule(
            (linear): Linear(in_features=10, out_features=5, bias=True)
          )
        )
        """

        assert issubclass(
            cls, torch.nn.Module
        ), f"{cls} does not derive from torch.nn.Module"

        super().__init__([cls(*args, **kwargs) for _ in range(size)])
class LearnFromRandomnessModel(L.LightningModule):
    """
    A PyTorch Lightning model for pretraining with the technique
    'Learning From Random Projectors'.

    References
    ----------
    Yi Sui, Tongzi Wu, Jesse C. Cresswell, Ga Wu, George Stein,
    Xiao Shi Huang, Xiaochen Zhang, Maksims Volkovs.
    "Self-supervised Representation Learning From Random Data Projectors", 2024.
    """

    def __init__(
        self,
        backbone: torch.nn.Module,
        projectors: torch.nn.ModuleList,
        predictors: torch.nn.ModuleList,
        loss_fn: torch.nn.Module,
        learning_rate: float = 1e-3,
        flatten: bool = True,
    ):
        """
        Initialize the LearnFromRandomnessModel.

        Parameters
        ----------
        backbone : torch.nn.Module
            The backbone neural network for feature extraction.
        projectors : torch.nn.ModuleList
            A list of projector networks.
        predictors : torch.nn.ModuleList
            A list of predictor networks.
        loss_fn : torch.nn.Module
            The loss function to optimize.
        learning_rate : float, optional
            The learning rate for the optimizer, by default 1e-3.
        flatten : bool, optional
            Whether to flatten the input tensor or not, by default True.
        """
        super().__init__()
        self.backbone = backbone
        self.projectors = projectors
        self.predictors = predictors
        self.loss_fn = loss_fn
        self.learning_rate = learning_rate
        self.flatten = flatten

        # The random projectors are fixed targets: freeze their parameters
        # and keep them in eval mode.
        for param in self.projectors.parameters():
            param.requires_grad = False

        for proj in self.projectors:
            proj.eval()
    def _loss_func(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Calculate the loss between the output and the input data.

        Parameters
        ----------
        y_hat : torch.Tensor
            The output data from the forward pass.
        y : torch.Tensor
            The input data/label.

        Returns
        -------
        torch.Tensor
            The loss value.
        """
        loss = self.loss_fn(y_hat, y)
        return loss

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the network.

        Parameters
        ----------
        x : torch.Tensor
            The input data.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            The predicted output and the projected input, each with the
            per-module outputs stacked along dimension 1.
        """
        z: torch.Tensor = self.backbone(x)

        if self.flatten:
            z = z.view(z.size(0), -1)
            x = x.view(x.size(0), -1)

        # Predictors operate on the backbone features; the frozen random
        # projectors operate directly on the (possibly flattened) input.
        y_pred = torch.stack([predictor(z) for predictor in self.predictors], 1)
        y_proj = torch.stack([projector(x) for projector in self.projectors], 1)

        return y_pred, y_proj
    def _single_step(
        self, batch: torch.Tensor, batch_idx: int, step_name: str
    ) -> torch.Tensor:
        """
        Perform a single training/validation/test step.

        Parameters
        ----------
        batch : torch.Tensor
            The input batch of data.
        batch_idx : int
            The index of the batch.
        step_name : str
            The name of the step (train, val, test).

        Returns
        -------
        torch.Tensor
            The loss value for the batch.
        """
        x = batch
        y_pred, y_proj = self.forward(x)
        loss = self._loss_func(y_pred, y_proj)
        self.log(
            f"{step_name}_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def training_step(self, batch: torch.Tensor, batch_idx: int):
        return self._single_step(batch, batch_idx, step_name="train")

    def validation_step(self, batch: torch.Tensor, batch_idx: int):
        return self._single_step(batch, batch_idx, step_name="val")

    def test_step(self, batch: torch.Tensor, batch_idx: int):
        return self._single_step(batch, batch_idx, step_name="test")

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
        )
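To make the intended training flow concrete, here is a minimal usage sketch, not part of the diff. It assumes the classes above and the `MLP` helper from the next file are importable together; the backbone choice, module sizes, and loss function are all illustrative assumptions.

    import torch
    import lightning as L
    from torch.utils.data import DataLoader

    # Hypothetical setup: 64-dim inputs, an MLP backbone producing 128-dim
    # features, and 8 pairs of frozen random projectors / trainable
    # predictors, each mapping to a 16-dim output.
    model = LearnFromRandomnessModel(
        backbone=MLP(64, 128),
        projectors=RepeatedModuleList(8, MLP, 64, 16),
        predictors=RepeatedModuleList(8, MLP, 128, 16),
        loss_fn=torch.nn.MSELoss(),
    )

    # A tensor works directly as a map-style Dataset (it has __getitem__ and
    # __len__), and the default collate yields plain tensors, which matches
    # _single_step treating the whole batch as the input.
    loader = DataLoader(torch.rand(32, 64), batch_size=8)

    y_pred, y_proj = model(torch.rand(4, 64))  # both of shape (4, 8, 16)

    L.Trainer(max_epochs=1, accelerator="cpu", logger=False).fit(model, loader)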
@@ -0,0 +1,52 @@
from torch import nn


class MLP(nn.Sequential):
    """
    A multilayer perceptron (MLP) implemented as a subclass of nn.Sequential.

    The MLP consists of a series of linear layers interleaved with ReLU
    activation functions, except for the last layer, which is purely linear.

    Example
    -------
    >>> mlp = MLP(10, 20, 30, 40)
    >>> print(mlp)
    MLP(
      (0): Linear(in_features=10, out_features=20, bias=True)
      (1): ReLU()
      (2): Linear(in_features=20, out_features=30, bias=True)
      (3): ReLU()
      (4): Linear(in_features=30, out_features=40, bias=True)
    )
    """

    def __init__(self, *layer_sizes):
        """
        Initializes the MLP with the given layer sizes.

        Parameters
        ----------
        *layer_sizes: int
            A variable number of positive integers specifying the size of each
            layer. There must be at least two integers, representing the input
            and output layers.

        Raises
        ------
        AssertionError:
            If fewer than two layer sizes are provided.
        AssertionError:
            If any layer size is not a positive integer.
        """
        assert (
            len(layer_sizes) >= 2
        ), "Multilayer perceptron must have at least 2 layers"
        # Check isinstance first so that non-integer inputs raise the
        # documented AssertionError rather than a TypeError from `>`.
        assert all(
            isinstance(ls, int) and ls > 0 for ls in layer_sizes
        ), "All layer sizes must be positive integers"

        # Linear + ReLU for every layer except the last, which stays linear.
        layers = []
        for i in range(len(layer_sizes) - 2):
            layers += [nn.Linear(layer_sizes[i], layer_sizes[i + 1]), nn.ReLU()]
        layers += [nn.Linear(layer_sizes[-2], layer_sizes[-1])]

        super().__init__(*layers)
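As a quick sanity check, a minimal sketch of a forward pass (the layer sizes here are arbitrary):

    import torch

    mlp = MLP(10, 20, 5)          # Linear(10, 20) -> ReLU() -> Linear(20, 5)
    out = mlp(torch.rand(4, 10))  # a batch of 4 ten-dimensional inputs
    assert out.shape == (4, 5)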
@@ -0,0 +1,60 @@
import torch
from torch.nn import Sequential, Conv2d, CrossEntropyLoss
from torchvision.transforms import Resize

from minerva.models.nets.lfr import RepeatedModuleList, LearnFromRandomnessModel
from minerva.models.nets.image.deeplabv3 import DeepLabV3Backbone


def test_lfr():

    # Example class for projector
    class Projector(Sequential):
        def __init__(self):
            super().__init__(
                Conv2d(3, 16, 5, 2),
                Conv2d(16, 64, 5, 2),
                Conv2d(64, 16, 5, 2),
                Resize((100, 50)),
            )

    # Example class for predictor
    class Predictor(Sequential):
        def __init__(self):
            super().__init__(Conv2d(2048, 16, 1), Resize((100, 50)))

    # Declare model
    model = LearnFromRandomnessModel(
        DeepLabV3Backbone(),
        RepeatedModuleList(5, Projector),
        RepeatedModuleList(5, Predictor),
        CrossEntropyLoss(),
        flatten=False,
    )

    # Test the class instantiation
    assert model is not None

    # Test the forward method
    input_shape = (2, 3, 701, 255)
    expected_output_size = torch.Size([2, 5, 16, 100, 50])
    x = torch.rand(*input_shape)

    y_pred, y_proj = model(x)
    assert (
        y_pred.shape == expected_output_size
    ), f"Expected output shape {expected_output_size}, but got {y_pred.shape}"

    assert (
        y_proj.shape == expected_output_size
    ), f"Expected output shape {expected_output_size}, but got {y_proj.shape}"

    # Test the _loss_func method
    loss = model._loss_func(y_pred, y_proj)
    assert loss is not None
    # TODO: assert the loss result

    # Test the configure_optimizers method
    optimizer = model.configure_optimizers()
    assert optimizer is not None
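One way the TODO above could eventually be filled in, as a hedged sketch rather than part of the PR: since CrossEntropyLoss defaults to mean reduction, the loss should at least be a finite scalar.

    # Possible follow-up for the TODO (an assumption about intent):
    assert loss.ndim == 0, "loss should be a scalar"
    assert torch.isfinite(loss), "loss should be finite"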
Review comment (marked as resolved by GabrielBG0): Define types for the `__init__` parameters.