Skip to content

Commit

Permalink
correcting typing
Browse files Browse the repository at this point in the history
  • Loading branch information
GabrielBG0 committed Jul 4, 2024
1 parent 2ee6e8c commit e9e46fb
Show file tree
Hide file tree
Showing 15 changed files with 116 additions and 129 deletions.
2 changes: 1 addition & 1 deletion minerva/analysis/metrics/pixel_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


class PixelAccuracy(Metric):
def __init__(self, dist_sync_on_step=False):
def __init__(self, dist_sync_on_step: bool = False):
"""
Initializes a PixelAccuracy metric object.
Expand Down
6 changes: 3 additions & 3 deletions minerva/data/datasets/supervised_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Optional, Tuple

import numpy as np

Expand Down Expand Up @@ -45,7 +45,7 @@ class SupervisedReconstructionDataset(SimpleDataset):
```
"""

def __init__(self, readers: List[_Reader], transforms: _Transform | None = None):
def __init__(self, readers: List[_Reader], transforms: Optional[_Transform] = None):
"""A simple dataset class for supervised reconstruction tasks.
Parameters
Expand All @@ -54,7 +54,7 @@ def __init__(self, readers: List[_Reader], transforms: _Transform | None = None)
List of data readers. It must contain exactly 2 readers.
The first reader for the input data and the second reader for the
target data.
transforms: _Transform | None
transforms: Optional[_Transform]
Optional data transformation pipeline.
Raises
Expand Down
12 changes: 6 additions & 6 deletions minerva/data/readers/csv_reader.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from typing import Union
from typing import List, Optional, Tuple, Union

import pandas as pd

from minerva.data.readers.tabular_reader import TabularReader


class CSVReader(TabularReader):
def __init__(
self,
path: str,
columns_to_select: Union[str, list[str]],
cast_to: str = None,
data_shape: tuple[int, ...] = None,
columns_to_select: Union[str, List[str]],
cast_to: Optional[str] = None,
data_shape: Optional[Tuple[int, ...]] = None,
):
df = pd.read_csv(path)
super().__init__(df, columns_to_select, cast_to, data_shape)


8 changes: 4 additions & 4 deletions minerva/data/readers/patched_array_reader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import Dict, List, Optional, Tuple

import numpy as np
from numpy.typing import ArrayLike
Expand All @@ -22,10 +22,10 @@ def __init__(
self,
data: ArrayLike,
data_shape: Tuple[int, ...],
stride: Tuple[int, ...] = None,
pad_width: Tuple[Tuple[int, int], ...] = None,
stride: Optional[Tuple[int, ...]] = None,
pad_width: Optional[Tuple[Tuple[int, int], ...]] = None,
pad_mode: str = "constant",
pad_kwargs: dict = None,
pad_kwargs: Optional[Dict] = None,
):
"""Reads data from a NumPy array and generates patches from it.
Expand Down
16 changes: 8 additions & 8 deletions minerva/data/readers/tabular_reader.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import re
from pathlib import Path
from typing import Union
from typing import List, Optional, Tuple, Union

import numpy as np
import re
import pandas as pd

from minerva.data.readers.reader import _Reader


class TabularReader(_Reader):
def __init__(
self,
df: pd.DataFrame,
columns_to_select: Union[str, list[str]],
cast_to: str = None,
data_shape: tuple[int, ...] = None,
columns_to_select: Union[str, List[str]],
cast_to: Optional[str] = None,
data_shape: Optional[Tuple[int, ...]] = None,
):
"""Reader to select columns from a DataFrame and return them as a NumPy
array. The DataFrame is indexed by the row number. Each row of the
Expand Down Expand Up @@ -64,9 +66,7 @@ def __getitem__(self, index: int) -> np.ndarray:
# Filter valid columns based on columns_to_select list
valid_columns = []
for pattern in self.columns_to_select:
valid_columns.extend(
[col for col in columns if re.match(pattern, col)]
)
valid_columns.extend([col for col in columns if re.match(pattern, col)])

# Select the elements and return
row = self.df.iloc[index][valid_columns]
Expand Down
28 changes: 14 additions & 14 deletions minerva/models/nets/image/setr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Dict, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import lightning as L
import torch
Expand Down Expand Up @@ -222,7 +222,7 @@ class _SetR_PUP(nn.Module):

def __init__(
self,
image_size: int | tuple[int, int],
image_size: Union[int, Tuple[int, int]],
patch_size: int,
num_layers: int,
num_heads: int,
Expand All @@ -241,14 +241,14 @@ def __init__(
conv_act: nn.Module,
align_corners: bool,
aux_output: bool = False,
aux_output_layers: list[int] | None = None,
aux_output_layers: Optional[List[int]] = None,
):
"""
Initializes the SETR PUP model.
Parameters
----------
image_size : int or tuple[int, int]
image_size : int or Tuple[int, int]
The size of the input image.
patch_size : int
The size of each patch in the input image.
Expand Down Expand Up @@ -388,7 +388,7 @@ class SETR_PUP(L.LightningModule):

def __init__(
self,
image_size: int | tuple[int, int] = 512,
image_size: Union[int, Tuple[int, int]] = 512,
patch_size: int = 16,
num_layers: int = 24,
num_heads: int = 16,
Expand All @@ -411,15 +411,15 @@ def __init__(
val_metrics: Optional[Dict[str, Metric]] = None,
test_metrics: Optional[Dict[str, Metric]] = None,
aux_output: bool = True,
aux_output_layers: list[int] | None = [9, 14, 19],
aux_weights: list[float] = [0.3, 0.3, 0.3],
aux_output_layers: Optional[List[int]] = [9, 14, 19],
aux_weights: List[float] = [0.3, 0.3, 0.3],
):
"""
Initializes the SetR model.
Parameters
----------
image_size : int or tuple[int, int]
image_size : int or Tuple[int, int]
The input image size. Defaults to 512.
patch_size : int
The size of each patch. Defaults to 16.
Expand Down Expand Up @@ -465,9 +465,9 @@ def __init__(
The metrics to be used for testing evaluation. Defaults to None.
aux_output : bool
Whether to include auxiliary output heads in the model. Defaults to True.
aux_output_layers : list[int] | None
aux_output_layers : List[int], optional
The indices of the layers to output auxiliary predictions. Defaults to [9, 14, 19].
aux_weights : list[float]
aux_weights : List[float]
The weights for the auxiliary predictions. Defaults to [0.3, 0.3, 0.3].
"""
Expand Down Expand Up @@ -536,9 +536,9 @@ def _compute_metrics(self, y_hat: torch.Tensor, y: torch.Tensor, step_name: str)

def _loss_func(
self,
y_hat: (
torch.Tensor | Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
),
y_hat: Union[
torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
],
y: torch.Tensor,
) -> torch.Tensor:
"""Calculate the loss between the output and the input data.
Expand Down Expand Up @@ -626,7 +626,7 @@ def test_step(self, batch: torch.Tensor, batch_idx: int):
return self._single_step(batch, batch_idx, "test")

def predict_step(
self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int | None = None
self, batch: torch.Tensor, batch_idx: int, dataloader_idx: Optional[int] = None
):
x, _ = batch
return self.model(x)[0]
Expand Down
14 changes: 7 additions & 7 deletions minerva/models/nets/image/vit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
from collections import OrderedDict
from functools import partial
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Tuple, Union

import lightning as L
import torch
Expand Down Expand Up @@ -31,7 +31,7 @@ def __init__(
dropout: float,
attention_dropout: float,
aux_output: bool = False,
aux_output_layers: List[int] | None = None,
aux_output_layers: Optional[List[int]] = None,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
Expand Down Expand Up @@ -81,7 +81,7 @@ class _VisionTransformerBackbone(nn.Module):

def __init__(
self,
image_size: int | tuple[int, int],
image_size: Union[int, Tuple[int, int]],
patch_size: int,
num_layers: int,
num_heads: int,
Expand All @@ -91,7 +91,7 @@ def __init__(
attention_dropout: float = 0.0,
num_classes: int = 1000,
aux_output: bool = False,
aux_output_layers: List[int] | None = None,
aux_output_layers: Optional[List[int]] = None,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
conv_stem_configs: Optional[List[ConvStemConfig]] = None,
):
Expand All @@ -100,7 +100,7 @@ def __init__(
Parameters
----------
image_size : int or tuple[int, int]
image_size : int or Tuple[int, int]
The size of the input image. If an int is provided, it is assumed
to be a square image. If a tuple of ints is provided, it represents the height and width of the image.
patch_size : int
Expand Down Expand Up @@ -234,14 +234,14 @@ def __init__(
if self.conv_proj.conv_last.bias is not None:
nn.init.zeros_(self.conv_proj.conv_last.bias)

def _process_input(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
def _process_input(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
"""Process the input tensor and return the reshaped tensor and dimensions.
Args:
x (torch.Tensor): The input tensor.
Returns:
tuple[torch.Tensor, int, int]: The reshaped tensor, number of rows, and number of columns.
Tuple[torch.Tensor, int, int]: The reshaped tensor, number of rows, and number of columns.
"""
n, c, h, w = x.shape
p = self.patch_size
Expand Down
14 changes: 9 additions & 5 deletions minerva/models/nets/time_series/cnns.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Tuple, Union

import torch
from torchmetrics import Accuracy
Expand All @@ -7,7 +7,7 @@


class ZeroPadder2D(torch.nn.Module):
def __init__(self, pad_at: List[int], padding_size: int):
def __init__(self, pad_at: Tuple[int], padding_size: int):
super().__init__()
self.pad_at = pad_at
self.padding_size = padding_size
Expand Down Expand Up @@ -56,7 +56,9 @@ def __init__(
test_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)},
)

def _create_backbone(self, input_shape: Tuple[int, int]) -> torch.nn.Module:
def _create_backbone(
self, input_shape: Union[Tuple[int, int], Tuple[int, int, int]]
) -> torch.nn.Module:
return torch.nn.Sequential(
# First 2D convolutional layer
torch.nn.Conv2d(
Expand Down Expand Up @@ -120,7 +122,7 @@ def _create_fc(self, input_features: int, num_classes: int) -> torch.nn.Module:
class CNN_HaEtAl_2D(SimpleSupervisedModel):
def __init__(
self,
pad_at: List[int] = (3,),
pad_at: Tuple[int] = (3,),
input_shape: Tuple[int, int, int] = (1, 6, 60),
num_classes: int = 6,
learning_rate: float = 1e-3,
Expand All @@ -144,7 +146,9 @@ def __init__(
test_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)},
)

def _create_backbone(self, input_shape: Tuple[int, int]) -> torch.nn.Module:
def _create_backbone(
self, input_shape: Union[Tuple[int, int], Tuple[int, int, int]]
) -> torch.nn.Module:
first_kernel_size = 4
return torch.nn.Sequential(
# Add padding
Expand Down
4 changes: 2 additions & 2 deletions minerva/models/nets/time_series/resnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Optional, Tuple

import torch
from torchmetrics import Accuracy
Expand All @@ -7,7 +7,7 @@


class ConvolutionalBlock(torch.nn.Module):
def __init__(self, in_channels: int, activation_cls: torch.nn.Module = None):
def __init__(self, in_channels: int, activation_cls: torch.nn.Module):
super().__init__()
self.in_channels = in_channels
self.activation_cls = activation_cls
Expand Down
Loading

0 comments on commit e9e46fb

Please sign in to comment.