From 8b359a927912966cf8d862d9df5d21922d4d3e69 Mon Sep 17 00:00:00 2001
From: fernandoGubiMarques
Date: Fri, 5 Jul 2024 15:32:30 +0000
Subject: [PATCH] implemented tests for LFR

---
 tests/models/nets/test_lfr.py | 60 +++++++++++++++++++++++++++++++++++
 1 file changed, 60 insertions(+)
 create mode 100644 tests/models/nets/test_lfr.py

diff --git a/tests/models/nets/test_lfr.py b/tests/models/nets/test_lfr.py
new file mode 100644
index 0000000..666e050
--- /dev/null
+++ b/tests/models/nets/test_lfr.py
@@ -0,0 +1,60 @@
+import torch
+from torch.nn import Sequential, Conv2d, CrossEntropyLoss
+from torchvision.transforms import Resize
+
+from minerva.models.nets.lfr import RepeatedModuleList, LearnFromRandomnessModel
+from minerva.models.nets.image.deeplabv3 import DeepLabV3Backbone
+
+
+def test_lfr():
+
+    ## Example class for projector
+    class Projector(Sequential):
+        def __init__(self):
+            super().__init__(
+                Conv2d(3, 16, 5, 2),
+                Conv2d(16, 64, 5, 2),
+                Conv2d(64, 16, 5, 2),
+                Resize((100, 50)),
+            )
+
+    ## Example class for predictor
+    class Predictor(Sequential):
+        def __init__(self):
+            super().__init__(Conv2d(2048, 16, 1), Resize((100, 50)))
+
+    # Declare the model
+    model = LearnFromRandomnessModel(
+        DeepLabV3Backbone(),
+        RepeatedModuleList(5, Projector),
+        RepeatedModuleList(5, Predictor),
+        CrossEntropyLoss(),
+        flatten=False,
+    )
+
+    # Test the class instantiation
+    assert model is not None
+
+    # Test the forward method
+    input_shape = (2, 3, 701, 255)
+    expected_output_size = torch.Size([2, 5, 16, 100, 50])
+    x = torch.rand(*input_shape)
+
+    y_pred, y_proj = model(x)
+    assert (
+        y_pred.shape == expected_output_size
+    ), f"Expected output shape {expected_output_size}, but got {y_pred.shape}"
+
+    assert (
+        y_proj.shape == expected_output_size
+    ), f"Expected output shape {expected_output_size}, but got {y_proj.shape}"
+
+    # Test the _loss_func method
+    loss = model._loss_func(y_pred, y_proj)
+    assert loss is not None
+    # TODO: assert the loss result
+
+    # Test the configure_optimizers method
+    optimizer = model.configure_optimizers()
+    assert optimizer is not None
+
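
Note on the asserted shape (not part of the patch): the test expects torch.Size([2, 5, 16, 100, 50]), which is consistent with the model stacking the outputs of the five repeated projector/predictor modules along a new dimension, giving (batch, num_modules, channels, H, W). The sketch below illustrates that assumed stacking behavior only; it uses a hypothetical stand-in nn.ModuleList rather than the real RepeatedModuleList, whose implementation is not shown in this patch.

import torch
from torch import nn

# Hypothetical stand-in for RepeatedModuleList(5, Predictor): five independently
# initialized modules, inferred only from how the test constructs the model.
predictors = nn.ModuleList([nn.Conv2d(3, 16, kernel_size=1) for _ in range(5)])

x = torch.rand(2, 3, 100, 50)
# Stacking the per-module outputs along dim=1 yields (batch, num_modules, C, H, W),
# matching the (2, 5, 16, 100, 50) shape asserted in the test above.
stacked = torch.stack([p(x) for p in predictors], dim=1)
assert stacked.shape == torch.Size([2, 5, 16, 100, 50])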