Skip to content

Commit

Permalink
Add mypy and partial typing (#1382)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippmwirth authored Sep 11, 2023
1 parent 0f20961 commit 80cd4e0
Show file tree
Hide file tree
Showing 55 changed files with 457 additions and 168 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/test_code_format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ jobs:
test:
name: Check
runs-on: ubuntu-latest

steps:
- name: Checkout Code
uses: actions/checkout@v3
Expand All @@ -21,7 +20,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.7"
- uses: actions/cache@v2
with:
path: ${{ env.pythonLocation }}
Expand All @@ -30,5 +29,7 @@ jobs:
run: pip install -e '.[all]'
- name: Run Format Check
run: |
export LIGHTLY_SERVER_LOCATION="localhost:-1"
make format-check
- name: Run Type Check
run: |
make type-check
11 changes: 9 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,15 @@ test:
test-fast:
pytest tests

# run format checks and tests
all-checks: format-check test
## check typing
type-check:
mypy lightly tests

## run format checks
static-checks: format-check type-check

## run format checks and tests
all-checks: static-checks test

## build source and wheel package
dist: clean
Expand Down
2 changes: 2 additions & 0 deletions lightly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
__name__ = "lightly"
__version__ = "1.4.17"


import os

# see if torchvision vision transformer is available
Expand All @@ -91,6 +92,7 @@
):
_torchvision_vit_available = False


if os.getenv("LIGHTLY_DID_VERSION_CHECK", "False") == "False":
os.environ["LIGHTLY_DID_VERSION_CHECK"] = "True"
import multiprocessing
Expand Down
2 changes: 1 addition & 1 deletion lightly/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
RETRY_MAX_RETRIES = 5


def retry(func, *args, **kwargs):
def retry(func, *args, **kwargs): # type: ignore
"""Repeats a function until it completes successfully or fails too often.
Args:
Expand Down
1 change: 1 addition & 0 deletions lightly/embedding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved


from lightly.embedding._base import BaseEmbedding
from lightly.embedding.embedding import SelfSupervisedEmbedding
1 change: 0 additions & 1 deletion lightly/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

from lightly.loss.barlow_twins_loss import BarlowTwinsLoss
from lightly.loss.dcl_loss import DCLLoss, DCLWLoss
from lightly.loss.dino_loss import DINOLoss
Expand Down
12 changes: 7 additions & 5 deletions lightly/transforms/byol_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import torchvision.transforms as T
from PIL.Image import Image
Expand Down Expand Up @@ -31,7 +31,7 @@ def __init__(
hf_prob: float = 0.5,
rr_prob: float = 0.0,
rr_degrees: Union[None, float, Tuple[float, float]] = None,
normalize: Union[None, dict] = IMAGENET_NORMALIZE,
normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
color_jitter = T.ColorJitter(
brightness=cj_strength * cj_bright,
Expand Down Expand Up @@ -67,7 +67,8 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
The transformed image.
"""
return self.transform(image)
transformed: Tensor = self.transform(image)
return transformed


class BYOLView2Transform:
Expand All @@ -90,7 +91,7 @@ def __init__(
hf_prob: float = 0.5,
rr_prob: float = 0.0,
rr_degrees: Union[None, float, Tuple[float, float]] = None,
normalize: Union[None, dict] = IMAGENET_NORMALIZE,
normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
color_jitter = T.ColorJitter(
brightness=cj_strength * cj_bright,
Expand Down Expand Up @@ -126,7 +127,8 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
The transformed image.
"""
return self.transform(image)
transformed: Tensor = self.transform(image)
return transformed


class BYOLTransform(MultiViewTransform):
Expand Down
9 changes: 5 additions & 4 deletions lightly/transforms/dino_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import PIL
import torchvision.transforms as T
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(
kernel_scale: Optional[float] = None,
sigmas: Tuple[float, float] = (0.1, 2),
solarization_prob: float = 0.2,
normalize: Union[None, dict] = IMAGENET_NORMALIZE,
normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
# first global crop
global_transform_0 = DINOViewTransform(
Expand Down Expand Up @@ -213,7 +213,7 @@ def __init__(
kernel_scale: Optional[float] = None,
sigmas: Tuple[float, float] = (0.1, 2),
solarization_prob: float = 0.2,
normalize: Union[None, dict] = IMAGENET_NORMALIZE,
normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
transform = [
T.RandomResizedCrop(
Expand Down Expand Up @@ -262,4 +262,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
The transformed image.
"""
return self.transform(image)
transformed: Tensor = self.transform(image)
return transformed
4 changes: 2 additions & 2 deletions lightly/transforms/fast_siam_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

from lightly.transforms.multi_view_transform import MultiViewTransform
from lightly.transforms.simsiam_transform import SimSiamViewTransform
Expand Down Expand Up @@ -89,7 +89,7 @@ def __init__(
hf_prob: float = 0.5,
rr_prob: float = 0.0,
rr_degrees: Union[None, float, Tuple[float, float]] = None,
normalize: Union[None, dict] = IMAGENET_NORMALIZE,
normalize: Union[None, Dict[str, List[float]]] = IMAGENET_NORMALIZE,
):
transforms = [
SimSiamViewTransform(
Expand Down
3 changes: 2 additions & 1 deletion lightly/transforms/gaussian_blur.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
from PIL import ImageFilter
from PIL.Image import Image


class GaussianBlur:
Expand Down Expand Up @@ -47,7 +48,7 @@ def __init__(
self.prob = prob
self.sigmas = sigmas

def __call__(self, sample):
def __call__(self, sample: Image) -> Image:
"""Blurs the image with a given probability.
Args:
Expand Down
7 changes: 4 additions & 3 deletions lightly/transforms/ijepa_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, Union
from typing import Dict, List, Tuple, Union

import torchvision.transforms as T
from PIL.Image import Image
Expand Down Expand Up @@ -30,7 +30,7 @@ def __init__(
self,
input_size: Union[int, Tuple[int, int]] = 224,
min_scale: float = 0.2,
normalize: dict = IMAGENET_NORMALIZE,
normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
):
transforms = [
T.RandomResizedCrop(
Expand All @@ -55,4 +55,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
The transformed image.
"""
return self.transform(image)
transformed: Tensor = self.transform(image)
return transformed
4 changes: 2 additions & 2 deletions lightly/transforms/image_grid_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple, Union
from typing import List, Sequence, Union

import torchvision.transforms as T
from PIL.Image import Image
Expand All @@ -19,7 +19,7 @@ class ImageGridTransform:
grids.
"""

def __init__(self, transforms):
def __init__(self, transforms: Sequence[T.Compose]):
self.transforms = transforms

def __call__(self, image: Union[Tensor, Image]) -> Union[List[Tensor], List[Image]]:
Expand Down
28 changes: 19 additions & 9 deletions lightly/transforms/jigsaw.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Copyright (c) 2021. Lightly AG and its affiliates.
# All Rights Reserved

from typing import List

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from PIL import Image as Image
from PIL.Image import Image as PILImage
from torch import Tensor
from torchvision import transforms as T


class Jigsaw(object):
Expand Down Expand Up @@ -34,7 +38,11 @@ class Jigsaw(object):
"""

def __init__(
self, n_grid=3, img_size=255, crop_size=64, transform=transforms.ToTensor()
self,
n_grid: int = 3,
img_size: int = 255,
crop_size: int = 64,
transform: T.Compose = T.ToTensor(),
):
self.n_grid = n_grid
self.img_size = img_size
Expand All @@ -47,7 +55,7 @@ def __init__(
self.yy = np.reshape(yy * self.grid_size, (n_grid * n_grid,))
self.xx = np.reshape(xx * self.grid_size, (n_grid * n_grid,))

def __call__(self, img):
def __call__(self, img: PILImage) -> Tensor:
"""Performs the Jigsaw augmentation
Args:
img:
Expand All @@ -59,7 +67,7 @@ def __call__(self, img):
r_x = np.random.randint(0, self.side + 1, self.n_grid * self.n_grid)
r_y = np.random.randint(0, self.side + 1, self.n_grid * self.n_grid)
img = np.asarray(img, np.uint8)
crops = []
crops: List[PILImage] = []
for i in range(self.n_grid * self.n_grid):
crops.append(
img[
Expand All @@ -68,7 +76,9 @@ def __call__(self, img):
:,
]
)
crops = [Image.fromarray(crop) for crop in crops]
crops = torch.stack([self.transform(crop) for crop in crops])
crops = crops[np.random.permutation(self.n_grid**2)]
return crops
crop_images = [Image.fromarray(crop) for crop in crops]
crop_tensors: Tensor = torch.stack(
[self.transform(crop) for crop in crop_images]
)
permutation: List[int] = np.random.permutation(self.n_grid**2).tolist()
return crop_tensors[permutation]
5 changes: 2 additions & 3 deletions lightly/transforms/mae_transform.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import List, Tuple, Union
from typing import Dict, List, Tuple, Union

import torchvision.transforms as T
from PIL.Image import Image
from torch import Tensor

from lightly.transforms.multi_view_transform import MultiViewTransform
from lightly.transforms.utils import IMAGENET_NORMALIZE


Expand Down Expand Up @@ -37,7 +36,7 @@ def __init__(
self,
input_size: Union[int, Tuple[int, int]] = 224,
min_scale: float = 0.2,
normalize: dict = IMAGENET_NORMALIZE,
normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
):
transforms = [
T.RandomResizedCrop(
Expand Down
4 changes: 2 additions & 2 deletions lightly/transforms/moco_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.transforms.utils import IMAGENET_NORMALIZE
Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(
hf_prob: float = 0.5,
rr_prob: float = 0.0,
rr_degrees: Union[None, float, Tuple[float, float]] = None,
normalize: dict = IMAGENET_NORMALIZE,
normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
):
super().__init__(
input_size=input_size,
Expand Down
9 changes: 5 additions & 4 deletions lightly/transforms/msn_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import torchvision.transforms as T
from PIL.Image import Image
Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(
random_gray_scale: float = 0.2,
hf_prob: float = 0.5,
vf_prob: float = 0.0,
normalize: dict = IMAGENET_NORMALIZE,
normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
):
random_view_transform = MSNViewTransform(
crop_size=random_size,
Expand Down Expand Up @@ -150,7 +150,7 @@ def __init__(
random_gray_scale: float = 0.2,
hf_prob: float = 0.5,
vf_prob: float = 0.0,
normalize: dict = IMAGENET_NORMALIZE,
normalize: Dict[str, List[float]] = IMAGENET_NORMALIZE,
):
color_jitter = T.ColorJitter(
brightness=cj_strength * cj_bright,
Expand Down Expand Up @@ -183,4 +183,5 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor:
The transformed image.
"""
return self.transform(image)
transformed: Tensor = self.transform(image)
return transformed
10 changes: 5 additions & 5 deletions lightly/transforms/multi_crop_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ class MultiCropTranform(MultiViewTransform):

def __init__(
self,
crop_sizes: Tuple[int],
crop_counts: Tuple[int],
crop_min_scales: Tuple[float],
crop_max_scales: Tuple[float],
transforms,
crop_sizes: Tuple[int, ...],
crop_counts: Tuple[int, ...],
crop_min_scales: Tuple[float, ...],
crop_max_scales: Tuple[float, ...],
transforms: T.Compose,
):
if len(crop_sizes) != len(crop_counts):
raise ValueError(
Expand Down
5 changes: 3 additions & 2 deletions lightly/transforms/multi_view_transform.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import List, Union
from typing import List, Sequence, Union

from PIL.Image import Image
from torch import Tensor
from torchvision import transforms as T


class MultiViewTransform:
Expand All @@ -13,7 +14,7 @@ class MultiViewTransform:
"""

def __init__(self, transforms):
def __init__(self, transforms: Sequence[T.Compose]):
self.transforms = transforms

def __call__(self, image: Union[Tensor, Image]) -> Union[List[Tensor], List[Image]]:
Expand Down
Loading

0 comments on commit 80cd4e0

Please sign in to comment.