diff --git a/.github/workflows/test_code_format.yml b/.github/workflows/test_code_format.yml
index 504756df3..61eb862fb 100644
--- a/.github/workflows/test_code_format.yml
+++ b/.github/workflows/test_code_format.yml
@@ -9,7 +9,6 @@ jobs:
   test:
     name: Check
     runs-on: ubuntu-latest
-
     steps:
       - name: Checkout Code
        uses: actions/checkout@v3
@@ -21,7 +20,7 @@ jobs:
       - name: Set up Python
        uses: actions/setup-python@v4
        with:
-          python-version: "3.10"
+          python-version: "3.7"
      - uses: actions/cache@v2
        with:
          path: ${{ env.pythonLocation }}
@@ -30,5 +29,7 @@
        run: pip install -e '.[all]'
      - name: Run Format Check
        run: |
-          export LIGHTLY_SERVER_LOCATION="localhost:-1"
          make format-check
+      - name: Run Type Check
+        run: |
+          make type-check
diff --git a/Makefile b/Makefile
index ead2c748d..fcb2ab17b 100644
--- a/Makefile
+++ b/Makefile
@@ -63,8 +63,15 @@ test:
 test-fast:
	pytest tests

-# run format checks and tests
-all-checks: format-check test
+## check typing
+type-check:
+	mypy lightly tests
+
+## run format checks
+static-checks: format-check type-check
+
+## run format checks and tests
+all-checks: static-checks test

 ## build source and wheel package
 dist: clean
diff --git a/lightly/__init__.py b/lightly/__init__.py
index 28a5a0a82..3f46ed904 100644
--- a/lightly/__init__.py
+++ b/lightly/__init__.py
@@ -77,6 +77,7 @@
 __name__ = "lightly"
 __version__ = "1.4.17"

+
 import os

 # see if torchvision vision transformer is available
@@ -91,6 +92,7 @@
 ):
     _torchvision_vit_available = False

+
 if os.getenv("LIGHTLY_DID_VERSION_CHECK", "False") == "False":
     os.environ["LIGHTLY_DID_VERSION_CHECK"] = "True"
     import multiprocessing
diff --git a/lightly/api/utils.py b/lightly/api/utils.py
index 5960fa7f6..d80bb2d67 100644
--- a/lightly/api/utils.py
+++ b/lightly/api/utils.py
@@ -24,7 +24,7 @@
 RETRY_MAX_RETRIES = 5


-def retry(func, *args, **kwargs):
+def retry(func, *args, **kwargs):  # type: ignore
    """Repeats a function until it completes successfully or fails too often.

    Args:
diff --git a/lightly/embedding/__init__.py b/lightly/embedding/__init__.py
index f2bb98c2d..3ee7a5480 100644
--- a/lightly/embedding/__init__.py
+++ b/lightly/embedding/__init__.py
@@ -8,5 +8,6 @@
 # Copyright (c) 2020. Lightly AG and its affiliates.
 # All Rights Reserved

+
 from lightly.embedding._base import BaseEmbedding
 from lightly.embedding.embedding import SelfSupervisedEmbedding
diff --git a/lightly/loss/__init__.py b/lightly/loss/__init__.py
index 706045fd1..613707abe 100644
--- a/lightly/loss/__init__.py
+++ b/lightly/loss/__init__.py
@@ -2,7 +2,6 @@
 # Copyright (c) 2020. Lightly AG and its affiliates.
 # All Rights Reserved

-
 from lightly.loss.barlow_twins_loss import BarlowTwinsLoss
 from lightly.loss.dcl_loss import DCLLoss, DCLWLoss
 from lightly.loss.dino_loss import DINOLoss
diff --git a/lightly/transforms/byol_transform.py b/lightly/transforms/byol_transform.py
index d3eff6a5e..26d9f1360 100644
--- a/lightly/transforms/byol_transform.py
+++ b/lightly/transforms/byol_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torchvision.transforms as T
 from PIL.Image import Image
@@ -31,7 +31,7 @@ def __init__(
         hf_prob: float = 0.5,
         rr_prob: float = 0.0,
         rr_degrees: Union[None, float, Tuple[float, float]] = None,
-        normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+        normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
     ):
         color_jitter = T.ColorJitter(
             brightness=cj_strength * cj_bright,
@@ -67,7 +67,8 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
             The transformed image.

         """
-        return self.transform(image)
+        transformed: Tensor = self.transform(image)
+        return transformed


 class BYOLView2Transform:
@@ -90,7 +91,7 @@ def __init__(
         hf_prob: float = 0.5,
         rr_prob: float = 0.0,
         rr_degrees: Union[None, float, Tuple[float, float]] = None,
-        normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+        normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
     ):
         color_jitter = T.ColorJitter(
             brightness=cj_strength * cj_bright,
@@ -126,7 +127,8 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
             The transformed image.

         """
-        return self.transform(image)
+        transformed: Tensor = self.transform(image)
+        return transformed


 class BYOLTransform(MultiViewTransform):
diff --git a/lightly/transforms/dino_transform.py b/lightly/transforms/dino_transform.py
index de6a18657..38e719fbd 100644
--- a/lightly/transforms/dino_transform.py
+++ b/lightly/transforms/dino_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import PIL
 import torchvision.transforms as T
@@ -117,7 +117,7 @@ def __init__(
         kernel_scale: Optional[float] = None,
         sigmas: Tuple[float, float] = (0.1, 2),
         solarization_prob: float = 0.2,
-        normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+        normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
     ):
         # first global crop
         global_transform_0 = DINOViewTransform(
@@ -213,7 +213,7 @@ def __init__(
         kernel_scale: Optional[float] = None,
         sigmas: Tuple[float, float] = (0.1, 2),
         solarization_prob: float = 0.2,
-        normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+        normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
     ):
         transform = [
             T.RandomResizedCrop(
@@ -262,4 +262,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
             The transformed image.
""" - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/transforms/fast_siam_transform.py b/lightly/transforms/fast_siam_transform.py index 7b560591f..c70b77de1 100644 --- a/lightly/transforms/fast_siam_transform.py +++ b/lightly/transforms/fast_siam_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union from lightly.transforms.multi_view_transform import MultiViewTransform from lightly.transforms.simsiam_transform import SimSiamViewTransform @@ -89,7 +89,7 @@ def __init__( hf_prob: float = 0.5, rr_prob: float = 0.0, rr_degrees: Union[None, float, Tuple[float, float]] = None, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): transforms = [ SimSiamViewTransform( diff --git a/lightly/transforms/gaussian_blur.py b/lightly/transforms/gaussian_blur.py index 180e6ff08..3f8f9e2d7 100644 --- a/lightly/transforms/gaussian_blur.py +++ b/lightly/transforms/gaussian_blur.py @@ -6,6 +6,7 @@ import numpy as np from PIL import ImageFilter +from PIL.Image import Image class GaussianBlur: @@ -47,7 +48,7 @@ def __init__( self.prob = prob self.sigmas = sigmas - def __call__(self, sample): + def __call__(self, sample: Image) -> Image: """Blurs the image with a given probability. Args: diff --git a/lightly/transforms/ijepa_transform.py b/lightly/transforms/ijepa_transform.py index 321dba66a..6bd002446 100644 --- a/lightly/transforms/ijepa_transform.py +++ b/lightly/transforms/ijepa_transform.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Dict, List, Tuple, Union import torchvision.transforms as T from PIL.Image import Image @@ -30,7 +30,7 @@ def __init__( self, input_size: Union[int, Tuple[int, int]] = 224, min_scale: float = 0.2, - normalize: dict = IMAGENET_NORMALIZE, + normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE, ): transforms = [ T.RandomResizedCrop( @@ -55,4 +55,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: The transformed image. """ - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/transforms/image_grid_transform.py b/lightly/transforms/image_grid_transform.py index 30854c910..1822761a6 100644 --- a/lightly/transforms/image_grid_transform.py +++ b/lightly/transforms/image_grid_transform.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Union +from typing import List, Sequence, Union import torchvision.transforms as T from PIL.Image import Image @@ -19,7 +19,7 @@ class ImageGridTransform: grids. """ - def __init__(self, transforms): + def __init__(self, transforms: Sequence[T.Compose]): self.transforms = transforms def __call__(self, image: Union[Tensor, Image]) -> Union[List[Tensor], List[Image]]: diff --git a/lightly/transforms/jigsaw.py b/lightly/transforms/jigsaw.py index db9cbc37b..adebb808b 100644 --- a/lightly/transforms/jigsaw.py +++ b/lightly/transforms/jigsaw.py @@ -1,10 +1,14 @@ # Copyright (c) 2021. Lightly AG and its affiliates. 
 # All Rights Reserved

+from typing import List
+
 import numpy as np
 import torch
-from PIL import Image
-from torchvision import transforms
+from PIL import Image as Image
+from PIL.Image import Image as PILImage
+from torch import Tensor
+from torchvision import transforms as T


 class Jigsaw(object):
@@ -34,7 +38,11 @@ class Jigsaw(object):
    """

     def __init__(
-        self, n_grid=3, img_size=255, crop_size=64, transform=transforms.ToTensor()
+        self,
+        n_grid: int = 3,
+        img_size: int = 255,
+        crop_size: int = 64,
+        transform: T.Compose = T.ToTensor(),
     ):
         self.n_grid = n_grid
         self.img_size = img_size
@@ -47,7 +55,7 @@ def __init__(
         self.yy = np.reshape(yy * self.grid_size, (n_grid * n_grid,))
         self.xx = np.reshape(xx * self.grid_size, (n_grid * n_grid,))

-    def __call__(self, img):
+    def __call__(self, img: PILImage) -> Tensor:
        """Performs the Jigsaw augmentation
        Args:
            img:
@@ -59,7 +67,7 @@ def __call__(self, img):
         r_x = np.random.randint(0, self.side + 1, self.n_grid * self.n_grid)
         r_y = np.random.randint(0, self.side + 1, self.n_grid * self.n_grid)
         img = np.asarray(img, np.uint8)
-        crops = []
+        crops: List[PILImage] = []
         for i in range(self.n_grid * self.n_grid):
             crops.append(
                 img[
@@ -68,7 +76,9 @@ def __call__(self, img):
                     :,
                 ]
             )
-        crops = [Image.fromarray(crop) for crop in crops]
-        crops = torch.stack([self.transform(crop) for crop in crops])
-        crops = crops[np.random.permutation(self.n_grid**2)]
-        return crops
+        crop_images = [Image.fromarray(crop) for crop in crops]
+        crop_tensors: Tensor = torch.stack(
+            [self.transform(crop) for crop in crop_images]
+        )
+        permutation: List[int] = np.random.permutation(self.n_grid**2).tolist()
+        return crop_tensors[permutation]
diff --git a/lightly/transforms/mae_transform.py b/lightly/transforms/mae_transform.py
index 3176f084e..50f9dd9f7 100644
--- a/lightly/transforms/mae_transform.py
+++ b/lightly/transforms/mae_transform.py
@@ -1,10 +1,9 @@
-from typing import List, Tuple, Union
+from typing import Dict, List, Tuple, Union

 import torchvision.transforms as T
 from PIL.Image import Image
 from torch import Tensor

-from lightly.transforms.multi_view_transform import MultiViewTransform
 from lightly.transforms.utils import IMAGENET_NORMALIZE


@@ -37,7 +36,7 @@ def __init__(
         self,
         input_size: Union[int, Tuple[int, int]] = 224,
         min_scale: float = 0.2,
-        normalize: dict = IMAGENET_NORMALIZE,
+        normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
     ):
         transforms = [
             T.RandomResizedCrop(
diff --git a/lightly/transforms/moco_transform.py b/lightly/transforms/moco_transform.py
index 8ec8ade55..3f5f728fe 100644
--- a/lightly/transforms/moco_transform.py
+++ b/lightly/transforms/moco_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 from lightly.transforms.simclr_transform import SimCLRTransform
 from lightly.transforms.utils import IMAGENET_NORMALIZE
@@ -83,7 +83,7 @@ def __init__(
         hf_prob: float = 0.5,
         rr_prob: float = 0.0,
         rr_degrees: Union[None, float, Tuple[float, float]] = None,
-        normalize: dict = IMAGENET_NORMALIZE,
+        normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
     ):
         super().__init__(
             input_size=input_size,
diff --git a/lightly/transforms/msn_transform.py b/lightly/transforms/msn_transform.py
index d07130732..b8a78ac1c 100644
--- a/lightly/transforms/msn_transform.py
+++ b/lightly/transforms/msn_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torchvision.transforms as T
 from PIL.Image import Image
@@ -96,7 +96,7 @@ def __init__(
         random_gray_scale: float = 0.2,
         hf_prob: float = 0.5,
         vf_prob: float = 0.0,
-        normalize: dict = IMAGENET_NORMALIZE,
+        normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
     ):
         random_view_transform = MSNViewTransform(
             crop_size=random_size,
@@ -150,7 +150,7 @@ def __init__(
         random_gray_scale: float = 0.2,
         hf_prob: float = 0.5,
         vf_prob: float = 0.0,
-        normalize: dict = IMAGENET_NORMALIZE,
+        normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
     ):
         color_jitter = T.ColorJitter(
             brightness=cj_strength * cj_bright,
@@ -183,4 +183,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
             The transformed image.

         """
-        return self.transform(image)
+        transformed: Tensor = self.transform(image)
+        return transformed
diff --git a/lightly/transforms/multi_crop_transform.py b/lightly/transforms/multi_crop_transform.py
index 307a507c6..2e546b22f 100644
--- a/lightly/transforms/multi_crop_transform.py
+++ b/lightly/transforms/multi_crop_transform.py
@@ -34,11 +34,11 @@ class MultiCropTranform(MultiViewTransform):

     def __init__(
         self,
-        crop_sizes: Tuple[int],
-        crop_counts: Tuple[int],
-        crop_min_scales: Tuple[float],
-        crop_max_scales: Tuple[float],
-        transforms,
+        crop_sizes: Tuple[int, ...],
+        crop_counts: Tuple[int, ...],
+        crop_min_scales: Tuple[float, ...],
+        crop_max_scales: Tuple[float, ...],
+        transforms: T.Compose,
     ):
         if len(crop_sizes) != len(crop_counts):
             raise ValueError(
diff --git a/lightly/transforms/multi_view_transform.py b/lightly/transforms/multi_view_transform.py
index a00f4fc00..62c9f2cb5 100644
--- a/lightly/transforms/multi_view_transform.py
+++ b/lightly/transforms/multi_view_transform.py
@@ -1,7 +1,8 @@
-from typing import List, Union
+from typing import List, Sequence, Union

 from PIL.Image import Image
 from torch import Tensor
+from torchvision import transforms as T


 class MultiViewTransform:
@@ -13,7 +14,7 @@ class MultiViewTransform:

    """

-    def __init__(self, transforms):
+    def __init__(self, transforms: Sequence[T.Compose]):
         self.transforms = transforms

     def __call__(self, image: Union[Tensor, Image]) -> Union[List[Tensor], List[Image]]:
diff --git a/lightly/transforms/pirl_transform.py b/lightly/transforms/pirl_transform.py
index b67e451c7..1d5ec7d57 100644
--- a/lightly/transforms/pirl_transform.py
+++ b/lightly/transforms/pirl_transform.py
@@ -1,8 +1,6 @@
-from typing import Tuple, Union
+from typing import Dict, List, Tuple, Union

 import torchvision.transforms as T
-from PIL.Image import Image
-from torch import Tensor

 from lightly.transforms.jigsaw import Jigsaw
 from lightly.transforms.multi_view_transform import MultiViewTransform
@@ -71,7 +69,7 @@ def __init__(
         random_gray_scale: float = 0.2,
         hf_prob: float = 0.5,
         n_grid: int = 3,
-        normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+        normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
     ):
         if isinstance(input_size, tuple):
             input_size_ = max(input_size)
@@ -79,13 +77,17 @@ def __init__(
             input_size_ = input_size

         # Cropping and normalisation for non-transformed image
-        no_augment = T.Compose(
-            [
-                T.RandomResizedCrop(size=input_size, scale=(min_scale, 1.0)),
-                T.ToTensor(),
-                T.Normalize(mean=normalize["mean"], std=normalize["std"]),
-            ]
-        )
+        transforms_no_augment = [
+            T.RandomResizedCrop(size=input_size, scale=(min_scale, 1.0)),
+            T.ToTensor(),
+        ]
+
+        if normalize is not None:
+            transforms_no_augment.append(
+                T.Normalize(mean=normalize["mean"], std=normalize["std"])
+            )
+
+        no_augment = T.Compose(transforms_no_augment)

         color_jitter = T.ColorJitter(
             brightness=cj_strength * cj_bright,
@@ -95,21 +97,21 @@ def __init__(
         )

         # Transform for transformed jigsaw image
-        transform = [
+        transforms = [
             T.RandomHorizontalFlip(p=hf_prob),
             T.RandomApply([color_jitter], p=cj_prob),
             T.RandomGrayscale(p=random_gray_scale),
             T.ToTensor(),
         ]

-        if normalize:
-            transform += [T.Normalize(mean=normalize["mean"], std=normalize["std"])]
+        if normalize is not None:
+            transforms.append(T.Normalize(mean=normalize["mean"], std=normalize["std"]))

         jigsaw = Jigsaw(
             n_grid=n_grid,
             img_size=input_size_,
             crop_size=int(input_size_ // n_grid),
-            transform=T.Compose(transform),
+            transform=T.Compose(transforms),
         )

         super().__init__([no_augment, jigsaw])
diff --git a/lightly/transforms/py.typed b/lightly/transforms/py.typed
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightly/transforms/random_crop_and_flip_with_grid.py b/lightly/transforms/random_crop_and_flip_with_grid.py
index 6fbb92a33..60a66226f 100644
--- a/lightly/transforms/random_crop_and_flip_with_grid.py
+++ b/lightly/transforms/random_crop_and_flip_with_grid.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, List, Tuple
+from typing import Tuple

 import torch
 import torchvision.transforms as T
@@ -28,7 +28,7 @@ class Location:
     vertical_flip: bool = False


-class RandomResizedCropWithLocation(T.RandomResizedCrop):
+class RandomResizedCropWithLocation(T.RandomResizedCrop):  # type: ignore[misc] # Class cannot subclass "RandomResizedCrop" (has type "Any")
    """
    Do a random resized crop and return both the resulting image and the location. See base class.

@@ -59,7 +59,7 @@ def forward(self, img: Image.Image) -> Tuple[Image.Image, Location]:
         return img, location


-class RandomHorizontalFlipWithLocation(T.RandomHorizontalFlip):
+class RandomHorizontalFlipWithLocation(T.RandomHorizontalFlip):  # type: ignore[misc] # Class cannot subclass "RandomHorizontalFlip" (has type "Any")
    """See base class."""

     def forward(
@@ -84,7 +84,7 @@ def forward(
         return img, location


-class RandomVerticalFlipWithLocation(T.RandomVerticalFlip):
+class RandomVerticalFlipWithLocation(T.RandomVerticalFlip):  # type: ignore[misc] # Class cannot subclass "RandomVerticalFlip" (has type "Any")
    """See base class."""

     def forward(
diff --git a/lightly/transforms/simclr_transform.py b/lightly/transforms/simclr_transform.py
index 8c39591c7..d975b0771 100644
--- a/lightly/transforms/simclr_transform.py
+++ b/lightly/transforms/simclr_transform.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torchvision.transforms as T
 from PIL.Image import Image
@@ -102,7 +102,7 @@ def __init__(
         hf_prob: float = 0.5,
         rr_prob: float = 0.0,
         rr_degrees: Union[None, float, Tuple[float, float]] = None,
-        normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+        normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
     ):
         view_transform = SimCLRViewTransform(
             input_size=input_size,
@@ -145,7 +145,7 @@ def __init__(
         hf_prob: float = 0.5,
         rr_prob: float = 0.0,
         rr_degrees: Union[None, float, Tuple[float, float]] = None,
-        normalize: Union[None, dict] = IMAGENET_NORMALIZE,
+        normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
     ):
         color_jitter = T.ColorJitter(
             brightness=cj_strength * cj_bright,
@@ -180,4 +180,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
             The transformed image.
""" - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/transforms/simsiam_transform.py b/lightly/transforms/simsiam_transform.py index 379d55242..4f2607480 100644 --- a/lightly/transforms/simsiam_transform.py +++ b/lightly/transforms/simsiam_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torchvision.transforms as T from PIL.Image import Image @@ -91,7 +91,7 @@ def __init__( hf_prob: float = 0.5, rr_prob: float = 0.0, rr_degrees: Union[None, float, Tuple[float, float]] = None, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): view_transform = SimSiamViewTransform( input_size=input_size, @@ -134,7 +134,7 @@ def __init__( hf_prob: float = 0.5, rr_prob: float = 0.0, rr_degrees: Union[None, float, Tuple[float, float]] = None, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): color_jitter = T.ColorJitter( brightness=cj_strength * cj_bright, @@ -169,4 +169,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: The transformed image. """ - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/transforms/smog_transform.py b/lightly/transforms/smog_transform.py index 6f8958377..69d8de64c 100644 --- a/lightly/transforms/smog_transform.py +++ b/lightly/transforms/smog_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torchvision.transforms as T from PIL.Image import Image @@ -88,7 +88,7 @@ def __init__( cj_sat: float = 0.4, cj_hue: float = 0.2, random_gray_scale: float = 0.2, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): transforms = [] for i in range(len(crop_sizes)): @@ -137,7 +137,7 @@ def __init__( cj_sat: float = 0.4, cj_hue: float = 0.2, random_gray_scale: float = 0.2, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): color_jitter = T.ColorJitter( brightness=cj_strength * cj_bright, @@ -175,4 +175,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: The transformed image. 
""" - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/transforms/solarize.py b/lightly/transforms/solarize.py index bbd460899..3e640a404 100644 --- a/lightly/transforms/solarize.py +++ b/lightly/transforms/solarize.py @@ -3,6 +3,7 @@ import numpy as np from PIL import ImageOps +from PIL.Image import Image as PILImage class RandomSolarization(object): @@ -22,7 +23,7 @@ def __init__(self, prob: float = 0.5, threshold: int = 128): self.prob = prob self.threshold = threshold - def __call__(self, sample): + def __call__(self, sample: PILImage) -> PILImage: """Solarizes the given input image Args: diff --git a/lightly/transforms/swav_transform.py b/lightly/transforms/swav_transform.py index 2cbabf94f..f7000f945 100644 --- a/lightly/transforms/swav_transform.py +++ b/lightly/transforms/swav_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torchvision.transforms as T from PIL.Image import Image @@ -96,7 +96,7 @@ def __init__( gaussian_blur: float = 0.5, kernel_size: Optional[float] = None, sigmas: Tuple[float, float] = (0.1, 2), - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): transforms = SwaVViewTransform( hf_prob=hf_prob, @@ -142,7 +142,7 @@ def __init__( gaussian_blur: float = 0.5, kernel_size: Optional[float] = None, sigmas: Tuple[float, float] = (0.1, 2), - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): color_jitter = T.ColorJitter( brightness=cj_strength * cj_bright, @@ -178,4 +178,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: The transformed image. """ - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/transforms/vicreg_transform.py b/lightly/transforms/vicreg_transform.py index e2234f6c4..c7cc0b270 100644 --- a/lightly/transforms/vicreg_transform.py +++ b/lightly/transforms/vicreg_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torchvision.transforms as T from PIL.Image import Image @@ -97,7 +97,7 @@ def __init__( hf_prob: float = 0.5, rr_prob: float = 0.0, rr_degrees: Union[None, float, Tuple[float, float]] = None, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): view_transform = VICRegViewTransform( input_size=input_size, @@ -142,7 +142,7 @@ def __init__( hf_prob: float = 0.5, rr_prob: float = 0.0, rr_degrees: Union[None, float, Tuple[float, float]] = None, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): color_jitter = T.ColorJitter( brightness=cj_strength * cj_bright, @@ -178,4 +178,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: The transformed image. 
""" - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/transforms/vicregl_transform.py b/lightly/transforms/vicregl_transform.py index 70baf139d..335c03c34 100644 --- a/lightly/transforms/vicregl_transform.py +++ b/lightly/transforms/vicregl_transform.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torchvision.transforms as T from PIL.Image import Image @@ -123,7 +123,7 @@ def __init__( cj_sat: float = 0.4, cj_hue: float = 0.2, random_gray_scale: float = 0.2, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): global_transform = ( RandomResizedCropAndFlip( @@ -189,7 +189,7 @@ def __init__( cj_sat: float = 0.4, cj_hue: float = 0.2, random_gray_scale: float = 0.2, - normalize: Union[None, dict] = IMAGENET_NORMALIZE, + normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE, ): color_jitter = T.ColorJitter( brightness=cj_strength * cj_bright, @@ -223,4 +223,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: Returns: The transformed image. """ - return self.transform(image) + transformed: Tensor = self.transform(image) + return transformed diff --git a/lightly/utils/hipify.py b/lightly/utils/hipify.py index fd2d7097b..44dc32aed 100644 --- a/lightly/utils/hipify.py +++ b/lightly/utils/hipify.py @@ -1,6 +1,6 @@ import copy import warnings -from typing import Type +from typing import Type, Union class bcolors: @@ -14,15 +14,19 @@ class bcolors: UNDERLINE = "\033[4m" -def _custom_formatwarning(msg, *args, **kwargs): +def _custom_formatwarning( + message: Union[str, Warning], + category: Type[Warning], + filename: str, + lineno: int, + line: Union[str, None] = None, +) -> str: # ignore everything except the message - return f"{bcolors.WARNING}{msg}{bcolors.WARNING}\n" + return f"{bcolors.WARNING}{message}{bcolors.WARNING}\n" -def print_as_warning(message: str, warning_class: Type[Warning] = UserWarning): +def print_as_warning(message: str, warning_class: Type[Warning] = UserWarning) -> None: old_format = copy.copy(warnings.formatwarning) - warnings.formatwarning = _custom_formatwarning warnings.warn(message, warning_class) - warnings.formatwarning = old_format diff --git a/lightly/utils/io.py b/lightly/utils/io.py index 00ca63a1c..556d35ef1 100644 --- a/lightly/utils/io.py +++ b/lightly/utils/io.py @@ -7,9 +7,10 @@ import json import re from itertools import compress -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Union import numpy as np +from numpy.typing import NDArray INVALID_FILENAME_CHARACTERS = [","] @@ -22,7 +23,7 @@ def _is_valid_filename(filename: str) -> bool: return True -def check_filenames(filenames: List[str]): +def check_filenames(filenames: List[str]) -> None: """Raises an error if one of the filenames is misformatted Args: @@ -35,7 +36,7 @@ def check_filenames(filenames: List[str]): raise ValueError(f"Invalid filename(s): {invalid_filenames}") -def check_embeddings(path: str, remove_additional_columns: bool = False): +def check_embeddings(path: str, remove_additional_columns: bool = False) -> None: """Raises an error if the embeddings csv file has not the correct format Use this check whenever you want to upload an embedding to the Lightly @@ -118,8 +119,8 @@ def check_embeddings(path: str, remove_additional_columns: bool = False): def save_embeddings( - path: str, embeddings: np.ndarray, labels: List[int], 
-):
+    path: str, embeddings: NDArray[np.float64], labels: List[int], filenames: List[str]
+) -> None:
    """Saves embeddings in a csv file in a Lightly compatible format.

    Creates a csv file at the location specified by path and saves embeddings,
@@ -167,7 +168,7 @@ def save_embeddings(
             writer.writerow([filename] + list(embedding) + [str(label)])


-def load_embeddings(path: str):
+def load_embeddings(path: str) -> Tuple[NDArray[np.float64], List[int], List[str]]:
    """Loads embeddings from a csv file in a Lightly compatible format.

    Args:
@@ -204,13 +205,13 @@ def load_embeddings(path: str):

     check_filenames(filenames)

-    embeddings = np.array(embeddings).astype(np.float32)
-    return embeddings, labels, filenames
+    embedding_array = np.array(embeddings).astype(np.float64)
+    return embedding_array, labels, filenames


 def load_embeddings_as_dict(
     path: str, embedding_name: str = "default", return_all: bool = False
-):
+) -> Union[Any, Tuple[Any, NDArray[np.float64], List[int], List[str]]]:
    """Loads embeddings from csv and store it in a dictionary for transfer.

    Loads embeddings to a dictionary which can be serialized and sent to the
@@ -245,10 +246,13 @@ def load_embeddings_as_dict(
     embeddings, labels, filenames = load_embeddings(path)

     # build dictionary
-    data = {"embeddingName": embedding_name, "embeddings": []}
-    for embedding, filename, label in zip(embeddings, filenames, labels):
-        item = {"fileName": filename, "value": embedding.tolist(), "label": label}
-        data["embeddings"].append(item)
+    data = {
+        "embeddingName": embedding_name,
+        "embeddings": [
+            {"fileName": filename, "value": embedding.tolist(), "label": label}
+            for embedding, filename, label in zip(embeddings, filenames, labels)
+        ],
+    }

     # return embeddings along with dictionary
     if return_all:
@@ -270,7 +274,9 @@ class COCO_ANNOTATION_KEYS:
     custom_metadata_image_id: str = "image_id"


-def format_custom_metadata(custom_metadata: List[Tuple[str, Dict]]):
+def format_custom_metadata(
+    custom_metadata: List[Tuple[str, Any]]
+) -> Dict[str, List[Any]]:
    """Transforms custom metadata into a format which can be handled by Lightly.

    Args:
@@ -293,7 +299,7 @@ def format_custom_metadata(custom_metadata: List[Tuple[str, Dict]]):
         >>> >    }

    """
-    formatted = {
+    formatted: Dict[str, List[Any]] = {
         COCO_ANNOTATION_KEYS.images: [],
         COCO_ANNOTATION_KEYS.custom_metadata: [],
     }
@@ -315,7 +321,7 @@ def format_custom_metadata(custom_metadata: List[Tuple[str, Dict]]):
     return formatted


-def save_custom_metadata(path: str, custom_metadata: List[Tuple[str, Dict]]):
+def save_custom_metadata(path: str, custom_metadata: List[Tuple[str, Any]]) -> None:
    """Saves custom metadata in a .json.

    Args:
@@ -333,7 +339,7 @@ def save_custom_metadata(path: str, custom_metadata: List[Tuple[str, Dict]]):
 def save_tasks(
     path: str,
     tasks: List[str],
-):
+) -> None:
    """Saves a list of prediction task names in the right format.

    Args:
@@ -347,7 +353,7 @@ def save_tasks(
         json.dump(tasks, f)


-def save_schema(path: str, task_type: str, ids: List[int], names: List[str]):
+def save_schema(path: str, task_type: str, ids: List[int], names: List[str]) -> None:
    """Saves a prediction schema in the right format.

    Args:
diff --git a/lightly/utils/lars.py b/lightly/utils/lars.py
index af14ff028..063149d36 100644
--- a/lightly/utils/lars.py
+++ b/lightly/utils/lars.py
@@ -1,5 +1,8 @@
+from typing import Any, Callable, Dict, Optional, Union
+
 import torch
-from torch.optim.optimizer import Optimizer, required
+from torch import Tensor
+from torch.optim.optimizer import Optimizer, required  # type: ignore[attr-defined]


 class LARS(Optimizer):
@@ -65,7 +68,7 @@ class LARS(Optimizer):

     def __init__(
         self,
-        params,
+        params: Any,
         lr: float = required,
         momentum: float = 0,
         dampening: float = 0,
@@ -95,14 +98,14 @@ def __init__(

         super().__init__(params, defaults)

-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         super().__setstate__(state)

         for group in self.param_groups:
             group.setdefault("nesterov", False)

     @torch.no_grad()
-    def step(self, closure=None):
+    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Args:
diff --git a/lightly/utils/version_compare.py b/lightly/utils/version_compare.py
index 898407cfb..05c7093f5 100644
--- a/lightly/utils/version_compare.py
+++ b/lightly/utils/version_compare.py
@@ -4,7 +4,7 @@
 # All Rights Reserved


-def version_compare(v0: str, v1: str):
+def version_compare(v0: str, v1: str) -> int:
    """Returns 1 if version of v0 is larger than v1 and -1 otherwise

    Use this method to compare Python package versions and see which one is
@@ -16,14 +16,14 @@ def version_compare(v0: str, v1: str):
         >>> version_compare('1.2.0', '1.1.2')
         >>> 1
    """
-    v0 = [int(n) for n in v0.split(".")][::-1]
-    v1 = [int(n) for n in v1.split(".")][::-1]
-    if len(v0) != 3 or len(v1) != 3:
+    v0_parsed = [int(n) for n in v0.split(".")][::-1]
+    v1_parsed = [int(n) for n in v1.split(".")][::-1]
+    if len(v0_parsed) != 3 or len(v1_parsed) != 3:
         raise ValueError(
             f"Length of version strings is not 3 (expected pattern `x.y.z`) but is "
-            f"{v0} and {v1}."
+            f"{v0_parsed} and {v1_parsed}."
         )
-    pairs = list(zip(v0, v1))[::-1]
+    pairs = list(zip(v0_parsed, v1_parsed))[::-1]
     for x, y in pairs:
         if x < y:
             return -1
diff --git a/mypy.ini b/mypy.ini
new file mode 100644
index 000000000..edf952772
--- /dev/null
+++ b/mypy.ini
@@ -0,0 +1,237 @@
+# Global options:
+
+[mypy]
+ignore_missing_imports = True
+python_version = 3.10
+warn_unused_configs = True
+strict_equality = True
+
+# Disallow dynamic typing
+disallow_any_decorated = True
+# TODO(Philipp, 09/23): Remove me!
+# disallow_any_explicit = True
+disallow_any_generics = True
+disallow_subclassing_any = True
+
+# Disallow untyped definitions
+disallow_untyped_defs = True
+disallow_incomplete_defs = True
+check_untyped_defs = True
+disallow_untyped_decorators = True
+
+# None and optional handling
+no_implicit_optional = True
+strict_optional = True
+
+# Configuring warnings
+warn_unused_ignores = True
+warn_no_return = True
+warn_return_any = True
+warn_redundant_casts = True
+warn_unreachable = True
+
+# Print format
+show_error_codes = True
+show_error_context = True
+
+# Plugins
+plugins = numpy.typing.mypy_plugin
+
+# Excludes
+# TODO(Philipp, 09/23): Remove these one by one (start with 300 files).
+exclude = (?x)(
+    lightly/cli/version_cli.py |
+    lightly/cli/crop_cli.py |
+    lightly/cli/serve_cli.py |
+    lightly/cli/embed_cli.py |
+    lightly/cli/lightly_cli.py |
+    lightly/cli/download_cli.py |
+    lightly/cli/config/get_config.py |
+    lightly/cli/train_cli.py |
+    lightly/cli/_cli_simclr.py |
+    lightly/cli/_helpers.py |
+    lightly/loss/ntx_ent_loss.py |
+    lightly/loss/vicreg_loss.py |
+    lightly/loss/memory_bank.py |
+    lightly/loss/tico_loss.py |
+    lightly/loss/pmsn_loss.py |
+    lightly/loss/swav_loss.py |
+    lightly/loss/negative_cosine_similarity.py |
+    lightly/loss/hypersphere_loss.py |
+    lightly/loss/msn_loss.py |
+    lightly/loss/dino_loss.py |
+    lightly/loss/sym_neg_cos_sim_loss.py |
+    lightly/loss/vicregl_loss.py |
+    lightly/loss/dcl_loss.py |
+    lightly/loss/regularizer/co2.py |
+    lightly/loss/barlow_twins_loss.py |
+    lightly/data/lightly_subset.py |
+    lightly/data/dataset.py |
+    lightly/data/collate.py |
+    lightly/data/_image.py |
+    lightly/data/_helpers.py |
+    lightly/data/_image_loaders.py |
+    lightly/data/_video.py |
+    lightly/data/_utils.py |
+    lightly/data/multi_view_collate.py |
+    lightly/embedding/_base.py |
+    lightly/embedding/callbacks.py |
+    lightly/embedding/embedding.py |
+    lightly/core.py |
+    lightly/api/api_workflow_compute_worker.py |
+    lightly/api/api_workflow_predictions.py |
+    lightly/api/download.py |
+    lightly/api/api_workflow_export.py |
+    lightly/api/api_workflow_download_dataset.py |
+    lightly/api/bitmask.py |
+    lightly/api/_version_checking.py |
+    lightly/api/serve.py |
+    lightly/api/patch.py |
+    lightly/api/swagger_api_client.py |
+    lightly/api/api_workflow_collaboration.py |
+    lightly/api/utils.py |
+    lightly/api/api_workflow_datasets.py |
+    lightly/api/api_workflow_selection.py |
+    lightly/api/swagger_rest_client.py |
+    lightly/api/api_workflow_datasources.py |
+    lightly/api/api_workflow_upload_embeddings.py |
+    lightly/api/api_workflow_client.py |
+    lightly/api/api_workflow_upload_metadata.py |
+    lightly/api/api_workflow_tags.py |
+    lightly/api/api_workflow_artifacts.py |
+    lightly/utils/cropping/crop_image_by_bounding_boxes.py |
+    lightly/utils/cropping/read_yolo_label_file.py |
+    lightly/utils/reordering.py |
+    lightly/utils/bounding_box.py |
+    lightly/utils/debug.py |
+    lightly/utils/dist.py |
+    lightly/utils/benchmarking/knn.py |
+    lightly/utils/benchmarking/linear_classifier.py |
+    lightly/utils/benchmarking/topk.py |
+    lightly/utils/benchmarking/metric_callback.py |
+    lightly/utils/benchmarking/benchmark_module.py |
+    lightly/utils/benchmarking/knn_classifier.py |
+    lightly/utils/benchmarking/online_linear_classifier.py |
+    lightly/utils/embeddings_2d.py |
+    lightly/models/_momentum.py |
+    lightly/models/batchnorm.py |
+    lightly/models/modules/heads.py |
+    lightly/models/modules/masked_autoencoder.py |
+    lightly/models/modules/nn_memory_bank.py |
+    lightly/models/modules/ijepa.py |
+    lightly/models/zoo.py |
+    lightly/models/utils.py |
+    lightly/models/resnet.py |
+    tests/cli/test_cli_version.py |
+    tests/cli/test_cli_magic.py |
+    tests/cli/test_cli_crop.py |
+    tests/cli/test_cli_download.py |
+    tests/cli/test_cli_train.py |
+    tests/cli/test_cli_get_lighty_config.py |
+    tests/cli/test_cli_embed.py |
+    tests/UNMOCKED_end2end_tests/delete_datasets_test_unmocked_cli.py |
+    tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py |
+    tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py |
+    tests/loss/test_NegativeCosineSimilarity.py |
+    tests/loss/test_MSNLoss.py |
+    tests/loss/test_DINOLoss.py |
+    tests/loss/test_VICRegLLoss.py |
+    tests/loss/test_CO2Regularizer.py |
+    tests/loss/test_DCLLoss.py |
+    tests/loss/test_barlow_twins_loss.py |
+    tests/loss/test_SymNegCosineSimilarityLoss.py |
+    tests/loss/test_NTXentLoss.py |
+    tests/loss/test_MemoryBank.py |
+    tests/loss/test_TicoLoss.py |
+    tests/loss/test_VICRegLoss.py |
+    tests/loss/test_PMSNLoss.py |
+    tests/loss/test_HyperSphere.py |
+    tests/loss/test_SwaVLoss.py |
+    tests/core/test_Core.py |
+    tests/data/test_multi_view_collate.py |
+    tests/data/test_data_collate.py |
+    tests/data/test_VideoDataset.py |
+    tests/data/test_LightlySubset.py |
+    tests/data/test_LightlyDataset.py |
+    tests/embedding/test_callbacks.py |
+    tests/embedding/test_embedding.py |
+    tests/api/test_serve.py |
+    tests/api/test_swagger_rest_client.py |
+    tests/api/test_rest_parser.py |
+    tests/api/test_utils.py |
+    tests/api/benchmark_video_download.py |
+    tests/api/test_BitMask.py |
+    tests/api/test_patch.py |
+    tests/api/test_download.py |
+    tests/api/test_version_checking.py |
+    tests/api/test_swagger_api_client.py |
+    tests/utils/test_debug.py |
+    tests/utils/benchmarking/test_benchmark_module.py |
+    tests/utils/benchmarking/test_topk.py |
+    tests/utils/benchmarking/test_online_linear_classifier.py |
+    tests/utils/benchmarking/test_knn_classifier.py |
+    tests/utils/benchmarking/test_knn.py |
+    tests/utils/benchmarking/test_linear_classifier.py |
+    tests/utils/benchmarking/test_metric_callback.py |
+    tests/utils/test_dist.py |
+    tests/utils/test_io.py |
+    tests/models/test_ModelsSimSiam.py |
+    tests/models/modules/test_masked_autoencoder.py |
+    tests/models/test_ModelsSimCLR.py |
+    tests/models/test_ModelUtils.py |
+    tests/models/test_ModelsNNCLR.py |
+    tests/models/test_ModelsMoCo.py |
+    tests/models/test_ProjectionHeads.py |
+    tests/models/test_ModelsBYOL.py |
+    tests/conftest.py |
+    tests/api_workflow/test_api_workflow_selection.py |
+    tests/api_workflow/test_api_workflow_datasets.py |
+    tests/api_workflow/mocked_api_workflow_client.py |
+    tests/api_workflow/test_api_workflow_compute_worker.py |
+    tests/api_workflow/test_api_workflow_artifacts.py |
+    tests/api_workflow/test_api_workflow_download_dataset.py |
+    tests/api_workflow/utils.py |
+    tests/api_workflow/test_api_workflow_client.py |
+    tests/api_workflow/test_api_workflow_export.py |
+    tests/api_workflow/test_api_workflow_datasources.py |
+    tests/api_workflow/test_api_workflow_tags.py |
+    tests/api_workflow/test_api_workflow_upload_custom_metadata.py |
+    tests/api_workflow/test_api_workflow_upload_embeddings.py |
+    tests/api_workflow/test_api_workflow_collaboration.py |
+    tests/api_workflow/test_api_workflow_predictions.py |
+    tests/api_workflow/test_api_workflow.py |
+    # Let's not type check deprecated active learning:
+    lightly/active_learning |
+    # Let's not type deprecated models:
+    lightly/models/simclr.py |
+    lightly/models/moco.py |
+    lightly/models/barlowtwins.py |
+    lightly/models/nnclr.py |
+    lightly/models/simsiam.py |
+    lightly/models/byol.py )
+
+# Ignore imports from untyped modules.
+[mypy-lightly.api.*]
+follow_imports = skip
+
+[mypy-lightly.cli.*]
+follow_imports = skip
+
+[mypy-lightly.data.*]
+follow_imports = skip
+
+[mypy-lightly.embedding.*]
+follow_imports = skip
+
+[mypy-lightly.loss.*]
+follow_imports = skip
+
+[mypy-lightly.models.*]
+follow_imports = skip
+
+[mypy-lightly.utils.benchmarking.*]
+follow_imports = skip
+
+# Ignore errors in auto generated code.
+[mypy-lightly.openapi_generated.*]
+ignore_errors = True
\ No newline at end of file
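
For context, and not part of the patch itself: a minimal sketch of what the strict options in the new mypy.ini enforce on modules that are not excluded. The function names below are hypothetical; with `disallow_untyped_defs = True` and `disallow_incomplete_defs = True`, mypy rejects the first definition and accepts the fully annotated second one.

from typing import List


def mean(values):  # flagged: Function is missing a type annotation  [no-untyped-def]
    return sum(values) / len(values)


def typed_mean(values: List[float]) -> float:  # accepted under the config above
    return sum(values) / len(values)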
diff --git a/requirements/dev.txt b/requirements/dev.txt
index dc74efe15..d5bc31f7b 100644
--- a/requirements/dev.txt
+++ b/requirements/dev.txt
@@ -22,3 +22,4 @@ torchmetrics
 lightning-bolts # for LARS optimizer
 black==23.1.0 # frozen version to avoid differences between CI and local dev machines
 isort==5.11.5 # frozen version to avoid differences between CI and local dev machines
+mypy==1.4.1 # frozen version to avoid differences between CI and local dev machines
diff --git a/tests/transforms/test_Solarize.py b/tests/transforms/test_Solarize.py
index 4a1cda349..1b595caf3 100644
--- a/tests/transforms/test_Solarize.py
+++ b/tests/transforms/test_Solarize.py
@@ -6,7 +6,7 @@


 class TestRandomSolarization(unittest.TestCase):
-    def test_on_pil_image(self):
+    def test_on_pil_image(self) -> None:
         for w in [32, 64, 128]:
             for h in [32, 64, 128]:
                 solarization = RandomSolarization(0.5)
diff --git a/tests/transforms/test_byol_transform.py b/tests/transforms/test_byol_transform.py
index b13950709..ecf72df48 100644
--- a/tests/transforms/test_byol_transform.py
+++ b/tests/transforms/test_byol_transform.py
@@ -7,14 +7,14 @@
 )


-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
     single_view_transform = BYOLView1Transform(input_size=32)
     sample = Image.new("RGB", (100, 100))
     output = single_view_transform(sample)
     assert output.shape == (3, 32, 32)


-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
     multi_view_transform = BYOLTransform(
         view_1_transform=BYOLView1Transform(input_size=32),
         view_2_transform=BYOLView2Transform(input_size=32),
diff --git a/tests/transforms/test_dino_transform.py b/tests/transforms/test_dino_transform.py
index 4ca06c721..74bfea478 100644
--- a/tests/transforms/test_dino_transform.py
+++ b/tests/transforms/test_dino_transform.py
@@ -3,14 +3,14 @@
 from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform


-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
     single_view_transform = DINOViewTransform(crop_size=32)
     sample = Image.new("RGB", (100, 100))
     output = single_view_transform(sample)
     assert output.shape == (3, 32, 32)


-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
     multi_view_transform = DINOTransform(global_crop_size=32, local_crop_size=8)
     sample = Image.new("RGB", (100, 100))
     output = multi_view_transform(sample)
diff --git a/tests/transforms/test_fastsiam_transform.py b/tests/transforms/test_fastsiam_transform.py
index a5f60dfdf..672cf0a41 100644
--- a/tests/transforms/test_fastsiam_transform.py
+++ b/tests/transforms/test_fastsiam_transform.py
@@ -3,7 +3,7 @@
 from lightly.transforms.fast_siam_transform import FastSiamTransform


-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
     multi_view_transform = FastSiamTransform(num_views=3, input_size=32)
     sample = Image.new("RGB", (100, 100))
     output = multi_view_transform(sample)
diff --git a/tests/transforms/test_GaussianBlur.py b/tests/transforms/test_gaussian_blur.py
similarity index 69%
rename from tests/transforms/test_GaussianBlur.py
rename to tests/transforms/test_gaussian_blur.py
index a778f15e4..fae09c674 100644
--- a/tests/transforms/test_GaussianBlur.py
+++ b/tests/transforms/test_gaussian_blur.py
@@ -2,21 +2,21 @@

 from PIL import Image

-from lightly.transforms import GaussianBlur
+from lightly.transforms.gaussian_blur import GaussianBlur


 class TestGaussianBlur(unittest.TestCase):
-    def test_on_pil_image(self):
+    def test_on_pil_image(self) -> None:
         for w in range(1, 100):
             for h in range(1, 100):
                 gaussian_blur = GaussianBlur()
                 sample = Image.new("RGB", (w, h))
                 gaussian_blur(sample)

-    def test_raise_kernel_size_deprecation(self):
+    def test_raise_kernel_size_deprecation(self) -> None:
         gaussian_blur = GaussianBlur(kernel_size=2)
         self.assertWarns(DeprecationWarning)

-    def test_raise_scale_deprecation(self):
+    def test_raise_scale_deprecation(self) -> None:
         gaussian_blur = GaussianBlur(scale=0.1)
         self.assertWarns(DeprecationWarning)
diff --git a/tests/transforms/test_Jigsaw.py b/tests/transforms/test_jigsaw.py
similarity index 66%
rename from tests/transforms/test_Jigsaw.py
rename to tests/transforms/test_jigsaw.py
index 964e738c5..6482ee109 100644
--- a/tests/transforms/test_Jigsaw.py
+++ b/tests/transforms/test_jigsaw.py
@@ -2,11 +2,11 @@

 from PIL import Image

-from lightly.transforms import Jigsaw
+from lightly.transforms.jigsaw import Jigsaw


 class TestJigsaw(unittest.TestCase):
-    def test_on_pil_image(self):
+    def test_on_pil_image(self) -> None:
         crop = Jigsaw()
         sample = Image.new("RGB", (255, 255))
         crop(sample)
diff --git a/tests/transforms/test_location_to_NxN_grid.py b/tests/transforms/test_location_to_NxN_grid.py
index 2ec1beb5b..013d4ab8f 100644
--- a/tests/transforms/test_location_to_NxN_grid.py
+++ b/tests/transforms/test_location_to_NxN_grid.py
@@ -3,7 +3,7 @@
 import lightly.transforms.random_crop_and_flip_with_grid as test_module


-def test_location_to_NxN_grid():
+def test_location_to_NxN_grid() -> None:
     # create a test instance of the Location class
     test_location = test_module.Location(
         left=10,
diff --git a/tests/transforms/test_mae_transform.py b/tests/transforms/test_mae_transform.py
index aafa11cdf..6f9b928c1 100644
--- a/tests/transforms/test_mae_transform.py
+++ b/tests/transforms/test_mae_transform.py
@@ -3,7 +3,7 @@
 from lightly.transforms.mae_transform import MAETransform


-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
     multi_view_transform = MAETransform(input_size=32)
     sample = Image.new("RGB", (100, 100))
     output = multi_view_transform(sample)
diff --git a/tests/transforms/test_moco_transform.py b/tests/transforms/test_moco_transform.py
index eef1c651f..aa43a216f 100644
--- a/tests/transforms/test_moco_transform.py
+++ b/tests/transforms/test_moco_transform.py
@@ -3,7 +3,7 @@
 from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform


-def test_moco_v1_multi_view_on_pil_image():
+def test_moco_v1_multi_view_on_pil_image() -> None:
     multi_view_transform = MoCoV1Transform(input_size=32)
     sample = Image.new("RGB", (100, 100))
     output = multi_view_transform(sample)
@@ -12,7 +12,7 @@ def test_moco_v1_multi_view_on_pil_image():
     assert output[1].shape == (3, 32, 32)


-def test_moco_v2_multi_view_on_pil_image():
+def test_moco_v2_multi_view_on_pil_image() -> None:
     multi_view_transform = MoCoV2Transform(input_size=32)
     sample = Image.new("RGB", (100, 100))
     output = multi_view_transform(sample)
diff --git a/tests/transforms/test_msn_transform.py b/tests/transforms/test_msn_transform.py
index fd2030bab..4f5be7b53 100644
--- a/tests/transforms/test_msn_transform.py
+++ b/tests/transforms/test_msn_transform.py
@@ -3,14 +3,14 @@
 from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform


-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
     single_view_transform = MSNViewTransform(crop_size=32)
     sample = Image.new("RGB", (100, 100))
     output = single_view_transform(sample)
     assert output.shape == (3, 32, 32)


-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
     multi_view_transform = MSNTransform(random_size=32, focal_size=8)
     sample = Image.new("RGB", (100, 100))
     output = multi_view_transform(sample)
diff --git a/tests/transforms/test_multi_view_transform.py b/tests/transforms/test_multi_view_transform.py
index dddd4db3d..27806b92a 100644
--- a/tests/transforms/test_multi_view_transform.py
+++ b/tests/transforms/test_multi_view_transform.py
@@ -6,7 +6,7 @@
 from lightly.transforms.multi_view_transform import MultiViewTransform


-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
     multi_view_transform = MultiViewTransform(
         [
             T.RandomHorizontalFlip(p=0.1),
diff --git a/tests/transforms/test_pirl_transform.py b/tests/transforms/test_pirl_transform.py
index 5042a1e8f..20c7c8705 100644
--- a/tests/transforms/test_pirl_transform.py
+++ b/tests/transforms/test_pirl_transform.py
@@ -3,7 +3,7 @@
 from lightly.transforms.pirl_transform import PIRLTransform


-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
     multi_view_transform = PIRLTransform(input_size=32)
     sample = Image.new("RGB", (100, 100))
     output = multi_view_transform(sample)
diff --git a/tests/transforms/test_rotation.py b/tests/transforms/test_rotation.py
index 448858fdd..065b8dd1d 100644
--- a/tests/transforms/test_rotation.py
+++ b/tests/transforms/test_rotation.py
@@ -1,3 +1,5 @@
+from typing import List, Tuple, Union
+
 from PIL import Image

 from lightly.transforms.rotation import (
@@ -7,20 +9,21 @@
 )


-def test_RandomRotate_on_pil_image():
+def test_RandomRotate_on_pil_image() -> None:
     random_rotate = RandomRotate()
     sample = Image.new("RGB", (100, 100))
     random_rotate(sample)


-def test_RandomRotateDegrees_on_pil_image():
-    for degrees in [0, 1, 45, (0, 0), (-15, 30)]:
+def test_RandomRotateDegrees_on_pil_image() -> None:
+    all_degrees: List[Union[float, Tuple[float, float]]] = [0, 1, 45, (0, 0), (-15, 30)]
+    for degrees in all_degrees:
         random_rotate = RandomRotateDegrees(prob=0.5, degrees=degrees)
         sample = Image.new("RGB", (100, 100))
         random_rotate(sample)


-def test_random_rotation_transform():
+def test_random_rotation_transform() -> None:
     transform = random_rotation_transform(rr_prob=1.0, rr_degrees=None)
     assert isinstance(transform, RandomRotate)
     transform = random_rotation_transform(rr_prob=1.0, rr_degrees=45)
diff --git a/tests/transforms/test_simclr_transform.py b/tests/transforms/test_simclr_transform.py
index 70fff7ab4..78a9a5cca 100644
--- a/tests/transforms/test_simclr_transform.py
+++ b/tests/transforms/test_simclr_transform.py
@@ -3,14 +3,14 @@
 from lightly.transforms.simclr_transform import SimCLRTransform, SimCLRViewTransform


-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
     single_view_transform = SimCLRViewTransform(input_size=32)
     sample = Image.new("RGB", (100, 100))
     output = single_view_transform(sample)
     assert output.shape == (3, 32, 32)


-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
     multi_view_transform = SimCLRTransform(input_size=32)
     sample = Image.new("RGB", (100, 100))
     output = multi_view_transform(sample)
diff --git a/tests/transforms/test_simsiam_transform.py b/tests/transforms/test_simsiam_transform.py
index 2444924ec..39a88721a 100644
--- a/tests/transforms/test_simsiam_transform.py
+++ b/tests/transforms/test_simsiam_transform.py
@@ -3,14 +3,14 @@
 from lightly.transforms.simsiam_transform import SimSiamTransform, SimSiamViewTransform


-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
     single_view_transform = SimSiamViewTransform(input_size=32)
     sample = Image.new("RGB", (100, 100))
     output = single_view_transform(sample)
     assert output.shape == (3, 32, 32)


-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
     multi_view_transform = SimSiamTransform(input_size=32)
     sample = Image.new("RGB", (100, 100))
     output = multi_view_transform(sample)
diff --git a/tests/transforms/test_smog_transform.py b/tests/transforms/test_smog_transform.py
index 49fed878b..042d46f9f 100644
--- a/tests/transforms/test_smog_transform.py
+++ b/tests/transforms/test_smog_transform.py
@@ -3,14 +3,14 @@
 from lightly.transforms.smog_transform import SMoGTransform, SmoGViewTransform


-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
     single_view_transform = SmoGViewTransform(crop_size=32)
     sample = Image.new("RGB", (100, 100))
     output = single_view_transform(sample)
     assert output.shape == (3, 32, 32)


-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
     multi_view_transform = SMoGTransform(crop_sizes=(32, 8))
     sample = Image.new("RGB", (100, 100))
     output = multi_view_transform(sample)
diff --git a/tests/transforms/test_swav_transform.py b/tests/transforms/test_swav_transform.py
index 05475ce6a..7c2cdd2c0 100644
--- a/tests/transforms/test_swav_transform.py
+++ b/tests/transforms/test_swav_transform.py
@@ -3,14 +3,14 @@
 from lightly.transforms.swav_transform import SwaVTransform, SwaVViewTransform


-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
     single_view_transform = SwaVViewTransform()
     sample = Image.new("RGB", (100, 100))
     output = single_view_transform(sample)
     assert output.shape == (3, 100, 100)


-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
     multi_view_transform = SwaVTransform(crop_sizes=(32, 8))
     sample = Image.new("RGB", (100, 100))
     output = multi_view_transform(sample)
diff --git a/tests/transforms/test_vicreg_transform.py b/tests/transforms/test_vicreg_transform.py
index 5a2b0633d..06e710f25 100644
--- a/tests/transforms/test_vicreg_transform.py
+++ b/tests/transforms/test_vicreg_transform.py
@@ -3,14 +3,14 @@
 from lightly.transforms.vicreg_transform import VICRegTransform, VICRegViewTransform


-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
     single_view_transform = VICRegViewTransform(input_size=32)
     sample = Image.new("RGB", (100, 100))
     output = single_view_transform(sample)
     assert output.shape == (3, 32, 32)


-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
     multi_view_transform = VICRegTransform(input_size=32)
     sample = Image.new("RGB", (100, 100))
     output = multi_view_transform(sample)
diff --git a/tests/transforms/test_vicregl_transform.py b/tests/transforms/test_vicregl_transform.py
index f4696d067..e697807c4 100644
--- a/tests/transforms/test_vicregl_transform.py
+++ b/tests/transforms/test_vicregl_transform.py
@@ -3,14 +3,14 @@
 from lightly.transforms.vicregl_transform import VICRegLTransform, VICRegLViewTransform


-def test_view_on_pil_image():
+def test_view_on_pil_image() -> None:
     single_view_transform = VICRegLViewTransform()
     sample = Image.new("RGB", (100, 100))
     output = single_view_transform(sample)
     assert output.shape == (3, 100, 100)


-def test_multi_view_on_pil_image():
+def test_multi_view_on_pil_image() -> None:
     multi_view_transform = VICRegLTransform(
         global_crop_size=32,
         local_crop_size=8,
diff --git a/tests/utils/test_scheduler.py b/tests/utils/test_scheduler.py
index b21f1f1d8..00d4bf4b4 100644
--- a/tests/utils/test_scheduler.py
+++ b/tests/utils/test_scheduler.py
@@ -7,7 +7,7 @@


 class TestScheduler(unittest.TestCase):
-    def test_cosine_schedule(self):
+    def test_cosine_schedule(self) -> None:
         self.assertAlmostEqual(cosine_schedule(1, 10, 0.99, 1.0), 0.99030154, 6)
         self.assertAlmostEqual(cosine_schedule(95, 100, 0.7, 2.0), 1.99477063, 6)
         self.assertAlmostEqual(cosine_schedule(0, 1, 0.996, 1.0), 1.0, 6)
@@ -23,7 +23,7 @@ def test_cosine_schedule(self):
         ):
             cosine_schedule(11, 10, 0.0, 1.0)

-    def test_CosineWarmupScheduler(self):
+    def test_CosineWarmupScheduler(self) -> None:
         model = nn.Linear(10, 1)
         optimizer = torch.optim.SGD(
             model.parameters(), lr=1.0, momentum=0.0, weight_decay=0.0
@@ -63,7 +63,7 @@ def test_CosineWarmupScheduler(self):
         ):
             scheduler.step()

-    def test_CosineWarmupScheduler__warmup(self):
+    def test_CosineWarmupScheduler__warmup(self) -> None:
         model = nn.Linear(10, 1)
         optimizer = torch.optim.SGD(
             model.parameters(), lr=1.0, momentum=0.0, weight_decay=0.0
diff --git a/tests/utils/test_version_compare.py b/tests/utils/test_version_compare.py
index ce39dbb6a..40d516ea0 100644
--- a/tests/utils/test_version_compare.py
+++ b/tests/utils/test_version_compare.py
@@ -4,7 +4,7 @@


 class TestVersionCompare(unittest.TestCase):
-    def test_valid_versions(self):
+    def test_valid_versions(self) -> None:
         # general test of smaller than version numbers
         self.assertEqual(version_compare.version_compare("0.1.4", "1.2.0"), -1)
         self.assertEqual(version_compare.version_compare("1.1.0", "1.2.0"), -1)
@@ -16,7 +16,7 @@ def test_valid_versions(self):
         # test equal
         self.assertEqual(version_compare.version_compare("1.2.0", "1.2.0"), 0)

-    def test_invalid_versions(self):
+    def test_invalid_versions(self) -> None:
         with self.assertRaises(ValueError):
             version_compare.version_compare("1.2", "1.1.0")
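
A short usage sketch (illustrative only, not part of the patch) of what the tightened `normalize` annotation buys: custom normalization stats must now be typed `Dict[str, List[float]]` rather than a bare `dict`, so mismatched value types are caught when the new `make type-check` target (which runs `mypy lightly tests` with the configuration from mypy.ini) is executed. The custom stats below are hypothetical values, chosen to match the shape of `IMAGENET_NORMALIZE`.

from typing import Dict, List

from lightly.transforms.simclr_transform import SimCLRTransform

# Custom stats with the same shape as IMAGENET_NORMALIZE; a dict whose
# values are not List[float] would now be rejected by mypy.
custom_normalize: Dict[str, List[float]] = {
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
}
transform = SimCLRTransform(input_size=32, normalize=custom_normalize)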