
Commit

Type tests/models (#1744)
vectorvp authored Nov 28, 2024
1 parent 6957178 commit 369af83
Showing 4 changed files with 127 additions and 89 deletions.
16 changes: 7 additions & 9 deletions pyproject.toml
@@ -292,14 +292,6 @@ exclude = '''(?x)(
     tests/utils/benchmarking/test_linear_classifier.py |
     tests/utils/benchmarking/test_metric_callback.py |
     tests/utils/test_dist.py |
-    tests/models/test_ModelsSimSiam.py |
-    tests/models/modules/test_masked_autoencoder.py |
-    tests/models/test_ModelsSimCLR.py |
-    tests/models/test_ModelUtils.py |
-    tests/models/test_ModelsNNCLR.py |
-    tests/models/test_ModelsMoCo.py |
-    tests/models/test_ProjectionHeads.py |
-    tests/models/test_ModelsBYOL.py |
     tests/conftest.py |
     tests/api_workflow/test_api_workflow_selection.py |
     tests/api_workflow/test_api_workflow_datasets.py |
@@ -325,7 +317,13 @@ exclude = '''(?x)(
     lightly/models/barlowtwins.py |
     lightly/models/nnclr.py |
     lightly/models/simsiam.py |
-    lightly/models/byol.py )'''
+    lightly/models/byol.py |
+    # Let's not type deprecated models tests:
+    tests/models/test_ModelsSimSiam.py |
+    tests/models/test_ModelsSimCLR.py |
+    tests/models/test_ModelsNNCLR.py |
+    tests/models/test_ModelsMoCo.py |
+    tests/models/test_ModelsBYOL.py )'''
 
 # Ignore imports from untyped modules.
 [[tool.mypy.overrides]]
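The edit above works because mypy's exclude option is a single verbose-mode regex: any file whose path matches the pattern is skipped, so removing an entry (here tests/models/modules/test_masked_autoencoder.py and the other non-deprecated tests/models files) is what brings those files back under type checking, while the deprecated model tests are re-added at the end of the pattern. A minimal, hypothetical Python sketch of this matching behavior, using a trimmed-down pattern rather than the real one from pyproject.toml:

# Illustrative only: a shortened, hypothetical version of the exclude regex,
# not the pattern shipped in pyproject.toml.
import re

EXCLUDE = re.compile(
    r"""(?x)(
    tests/models/test_ModelsBYOL.py |
    tests/models/test_ModelsMoCo.py )"""
)

# Still listed in the pattern, therefore still skipped by mypy.
assert EXCLUDE.search("tests/models/test_ModelsBYOL.py") is not None
# Removed from the pattern, therefore type checked again.
assert EXCLUDE.search("tests/models/modules/test_masked_autoencoder.py") is None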
40 changes: 25 additions & 15 deletions tests/models/modules/test_masked_autoencoder.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import unittest
 
 import torch
@@ -7,20 +9,24 @@
 from lightly.utils import dependency
 
 if dependency.torchvision_vit_available():
+    from torchvision.models.vision_transformer import VisionTransformer
+
     from lightly.models.modules import MAEBackbone, MAEDecoder, MAEEncoder
 
 
 @unittest.skipUnless(
     dependency.torchvision_vit_available(), "Torchvision ViT not available"
 )
 class TestMAEEncoder(unittest.TestCase):
-    def _vit(self):
+    def _vit(self) -> VisionTransformer:
         return torchvision.models.vision_transformer.vit_b_32(progress=False)
 
-    def test_from_vit(self):
+    def test_from_vit(self) -> None:
         MAEEncoder.from_vit_encoder(self._vit().encoder)
 
-    def _test_forward(self, device, batch_size=8, seed=0):
+    def _test_forward(
+        self, device: torch.device, batch_size: int = 8, seed: int = 0
+    ) -> None:
         torch.manual_seed(seed)
         vit = self._vit()
         encoder = MAEEncoder.from_vit_encoder(vit.encoder).to(device)
@@ -42,25 +48,27 @@ def _test_forward(self, device, batch_size=8, seed=0):
         # output must have reasonable numbers
         self.assertTrue(torch.all(torch.not_equal(out, torch.inf)))
 
-    def test_forward(self):
-        self._test_forward(torch.device("cpu"))
+    def test_forward(self) -> None:
+        self._test_forward(device=torch.device("cpu"))
 
     @unittest.skipUnless(torch.cuda.is_available(), "Cuda not available.")
-    def test_forward_cuda(self):
+    def test_forward_cuda(self) -> None:
         self._test_forward(torch.device("cuda"))
 
 
 @unittest.skipUnless(
     dependency.torchvision_vit_available(), "Torchvision ViT not available"
 )
 class TestMAEBackbone(unittest.TestCase):
-    def _vit(self):
+    def _vit(self) -> torchvision.models.vision_transformer.VisionTransformer:
         return torchvision.models.vision_transformer.vit_b_32(progress=False)
 
-    def test_from_vit(self):
+    def test_from_vit(self) -> None:
         MAEBackbone.from_vit(self._vit())
 
-    def _test_forward(self, device, batch_size=8, seed=0):
+    def _test_forward(
+        self, device: torch.device, batch_size: int = 8, seed: int = 0
+    ) -> None:
         torch.manual_seed(seed)
         vit = self._vit()
         backbone = MAEBackbone.from_vit(vit).to(device)
@@ -80,11 +88,11 @@ def _test_forward(self, device, batch_size=8, seed=0):
         # output must have reasonable numbers
         self.assertTrue(torch.all(torch.not_equal(class_tokens, torch.inf)))
 
-    def test_forward(self):
+    def test_forward(self) -> None:
         self._test_forward(torch.device("cpu"))
 
     @unittest.skipUnless(torch.cuda.is_available(), "Cuda not available.")
-    def test_forward_cuda(self):
+    def test_forward_cuda(self) -> None:
         self._test_forward(torch.device("cuda"))
 
     def test_images_to_tokens(self) -> None:
@@ -102,7 +110,7 @@ def test_images_to_tokens(self) -> None:
     dependency.torchvision_vit_available(), "Torchvision ViT not available"
 )
 class TestMAEDecoder(unittest.TestCase):
-    def test_init(self):
+    def test_init(self) -> None:
         MAEDecoder(
             seq_length=50,
             num_layers=2,
@@ -113,7 +121,9 @@ def test_init(self):
             out_dim=3 * 32**2,
         )
 
-    def _test_forward(self, device, batch_size=8, seed=0):
+    def _test_forward(
+        self, device: torch.device, batch_size: int = 8, seed: int = 0
+    ) -> None:
         torch.manual_seed(seed)
         seq_length = 50
         embed_input_dim = 128
@@ -137,9 +147,9 @@ def _test_forward(self, device, batch_size=8, seed=0):
         # output must have reasonable numbers
         self.assertTrue(torch.all(torch.not_equal(predictions, torch.inf)))
 
-    def test_forward(self):
+    def test_forward(self) -> None:
         self._test_forward(torch.device("cpu"))
 
     @unittest.skipUnless(torch.cuda.is_available(), "Cuda not available.")
-    def test_forward_cuda(self):
+    def test_forward_cuda(self) -> None:
         self._test_forward(torch.device("cuda"))
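Taken together with the pyproject.toml change, the pattern in this diff is: add from __future__ import annotations, give test helpers explicit parameter types such as torch.device and int, and annotate every test method with -> None so mypy can check the module once it leaves the exclude list. Below is a self-contained, illustrative sketch of that pattern under the same unittest conventions; it is not code from this commit and does not use the lightly MAE modules:

# Illustrative sketch of the typing pattern applied in this commit, not code from it.
from __future__ import annotations

import unittest

import torch
from torch import nn


class TestTypedForward(unittest.TestCase):
    def _model(self) -> nn.Module:
        return nn.Linear(16, 4)

    def _test_forward(
        self, device: torch.device, batch_size: int = 8, seed: int = 0
    ) -> None:
        torch.manual_seed(seed)
        model = self._model().to(device)
        images = torch.rand(batch_size, 16, device=device)
        out = model(images)
        # Output must have the expected shape and contain finite numbers.
        self.assertEqual(out.shape, (batch_size, 4))
        self.assertTrue(torch.all(torch.isfinite(out)))

    def test_forward(self) -> None:
        self._test_forward(device=torch.device("cpu"))

    @unittest.skipUnless(torch.cuda.is_available(), "Cuda not available.")
    def test_forward_cuda(self) -> None:
        self._test_forward(device=torch.device("cuda"))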
