diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index acc9452..10948db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,4 +50,5 @@ repos: - torch - traitlets - timm + - kornia args: [ --disallow-any-generics, --disallow-untyped-defs, --disable-error-code=import-untyped] diff --git a/README.md b/README.md index fea5e96..30c78c5 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,17 @@ For caching some imports are only required, these have to be manually installed - pyarrow >= 6.0.0 (Read parquet files) - annotated-types >= 0.6.0 +### Model + +There is support for using timm models. To be able to do so the user must manually install timm. +- timm >= 0.9.16 + +### Augmentation + +There is also implementations of augmentations that are not in commonly used packages. Most of these are for time series data but there are implmenetations for CutMix and MixUp for images that can be used in the pipeline. To be able to use these the user must manually install kornia. + +- kornia >= 0.7.2 + ## Documentation Documentation is generated using [Sphinx](https://www.sphinx-doc.org/en/master/). diff --git a/epochalyst/pipeline/model/training/augmentation/__init__.py b/epochalyst/pipeline/model/training/augmentation/__init__.py new file mode 100644 index 0000000..758ab3e --- /dev/null +++ b/epochalyst/pipeline/model/training/augmentation/__init__.py @@ -0,0 +1 @@ +"""Module containing implementation for augmentations.""" diff --git a/epochalyst/pipeline/model/training/augmentation/image_augmentations.py b/epochalyst/pipeline/model/training/augmentation/image_augmentations.py new file mode 100644 index 0000000..bf042f7 --- /dev/null +++ b/epochalyst/pipeline/model/training/augmentation/image_augmentations.py @@ -0,0 +1,103 @@ +"""Contains implementation of several image augmentations using PyTorch.""" + +from dataclasses import dataclass, field +from typing import Any + +import torch + + +def get_kornia_mix() -> Any: # noqa: ANN401 + """Return kornia mix.""" + try: + import kornia + + except ImportError: + raise ImportError( + "If you want to use this augmentation you must install kornia", + ) from None + + else: + return kornia.augmentation._2d.mix # noqa: SLF001 + + +@dataclass +class CutMix: + """2D CutMix implementation for spectrogram data augmentation. + + :param cut_size: The size of the cut + :param same_on_batch: Apply the same transformation across the batch + :param p: The probability of applying the filter + """ + + cut_size: tuple[float, float] = field(default=(0.0, 1.0)) + same_on_batch: bool = False + p: float = 0.5 + + def __post_init__(self) -> None: + """Check if the filter type is valid.""" + self.cutmix = get_kornia_mix().cutmix.RandomCutMixV2( + p=self.p, + cut_size=self.cut_size, + same_on_batch=self.same_on_batch, + data_keys=["input", "class"], + ) + + def __call__( + self, + x: torch.Tensor, + y: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Randomly patch the input with another sample. + + :param x: Input images. (N,C,W,H) + :param y: Input labels. (N,C) + """ + dummy_labels = torch.arange(x.size(0)) + augmented_x, augmentation_info = self.cutmix(x, dummy_labels) + augmentation_info = augmentation_info[0] + + y = y.float() + y_result = y.clone() + for i in range(augmentation_info.shape[0]): + y_result[i] = y[i] * (1 - augmentation_info[i, 2]) + y[int(augmentation_info[i, 1])] * augmentation_info[i, 2] + + return augmented_x, y_result + + +@dataclass +class MixUp: + """2D MixUp implementation for spectrogram data augmentation. + + :param lambda_val: The range of the mixup coefficient + :param same_on_batch: Apply the same transformation across the batch + :param p: The probability of applying the filter + """ + + lambda_val: tuple[float, float] = field(default=(0.0, 1.0)) + same_on_batch: bool = False + p: float = 0.5 + + def __post_init__(self) -> None: + """Check if the filter type is valid.""" + self.mixup = get_kornia_mix().mixup.RandomMixUpV2( + p=self.p, + lambda_val=self.lambda_val, + same_on_batch=self.same_on_batch, + data_keys=["input", "class"], + ) + + def __call__( + self, + x: torch.Tensor, + y: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Randomly patch the input with another sample.""" + dummy_labels = torch.arange(x.size(0)) + augmented_x, augmentation_info = self.mixup(x, dummy_labels) + + y = y.float() + y_result = y.clone() + for i in range(augmentation_info.shape[0]): + y_result[i] = y[i] * (1 - augmentation_info[i, 2]) + y[int(augmentation_info[i, 1])] * augmentation_info[i, 2] + + return augmented_x, y_result diff --git a/epochalyst/pipeline/model/training/augmentation/time_series_augmentations.py b/epochalyst/pipeline/model/training/augmentation/time_series_augmentations.py new file mode 100644 index 0000000..6195c15 --- /dev/null +++ b/epochalyst/pipeline/model/training/augmentation/time_series_augmentations.py @@ -0,0 +1,206 @@ +"""Contains implementation of several custom time series augmentations using PyTorch.""" + +from dataclasses import dataclass + +import numpy as np +import torch + + +@dataclass +class CutMix1D(torch.nn.Module): + """CutMix augmentation for 1D signals. + + Randomly select a percentage between 'low' and 'high' to preserve on the left side of the signal. + The right side will be replaced by the corresponding range from another sample from the batch. + The labels become the weighted average of the mixed signals where weights are the mix ratios. + """ + + p: float = 0.5 + low: float = 0 + high: float = 1 + + def __call__( + self, + x: torch.Tensor, + y: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Appply CutMix to the batch of 1D signal. + + :param x: Input features. (N,C,L) + :param y: Input labels. (N,C) + :return: The augmented features and labels + """ + indices = torch.arange(x.shape[0], device=x.device, dtype=torch.int) + shuffled_indices = torch.randperm(indices.shape[0]) + + low_len = int(self.low * x.shape[-1]) + high_len = int(self.high * x.shape[-1]) + cutoff_indices = torch.randint( + low_len, + high_len, + (x.shape[-1],), + device=x.device, + dtype=torch.int, + ) + cutoff_rates = cutoff_indices.float() / x.shape[-1] + + augmented_x = x.clone() + augmented_y = y.clone().float() + for i in range(x.shape[0]): + if torch.rand(1) < self.p: + augmented_x[i, :, cutoff_indices[i] :] = x[ + shuffled_indices[i], + :, + cutoff_indices[i] :, + ] + augmented_y[i] = y[i] * cutoff_rates[i] + y[shuffled_indices[i]] * (1 - cutoff_rates[i]) + return augmented_x, augmented_y + + +@dataclass +class MixUp1D(torch.nn.Module): + """MixUp augmentation for 1D signals. + + Randomly takes the weighted average of 2 samples and their labels with random weights. + """ + + p: float = 0.5 + + def __call__( + self, + x: torch.Tensor, + y: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Appply MixUp to the batch of 1D signal. + + :param x: Input features. (N,C,L)|(N,L) + :param y: Input labels. (N,C) + :return: The augmented features and labels + """ + indices = torch.arange(x.shape[0], device=x.device, dtype=torch.int) + shuffled_indices = torch.randperm(indices.shape[0]) + + augmented_x = x.clone() + augmented_y = y.clone().float() + for i in range(x.shape[0]): + if torch.rand(1) < self.p: + lambda_ = torch.rand(1, device=x.device) + augmented_x[i] = lambda_ * x[i] + (1 - lambda_) * x[shuffled_indices[i]] + augmented_y[i] = lambda_ * y[i] + (1 - lambda_) * y[shuffled_indices[i]] + return augmented_x, augmented_y + + +@dataclass +class Mirror1D(torch.nn.Module): + """Mirror augmentation for 1D signals. + + Mirrors the signal around its mean in the horizontal(time) axis. + """ + + p: float = 0.5 + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the augmentation to the input signal. + + :param x: Input features. (N,C,L)|(N,L) + :return: Augmented features. (N,C,L)|(N,L) + """ + augmented_x = x.clone() + for i in range(x.shape[0]): + if torch.rand(1) < self.p: + augmented_x[i] = -1 * x[i] + 2 * x[i].mean(dim=-1).unsqueeze(-1) + return augmented_x + + +@dataclass +class RandomAmplitudeShift(torch.nn.Module): + """Randomly scale the amplitude of all the frequencies in the signal.""" + + low: float = 0.5 + high: float = 1.5 + p: float = 0.5 + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the augmentation to the input signal. + + :param x: Input features. (N,C,L)|(N,L) + :return: Augmented features. (N,C,L)|(N,L) + """ + if torch.rand(1) < self.p: + # Take the rfft of the input tensor + x_freq = torch.fft.rfft(x, dim=-1) + # Create a random tensor of scaler in the range [low,high] + random_amplitude = torch.rand(*x_freq.shape, device=x.device, dtype=x.dtype) * (self.high - self.low) + self.low + # Multiply the rfft with the random amplitude + x_freq = x_freq * random_amplitude + # Take the irfft of the result + return torch.fft.irfft(x_freq, dim=-1) + return x + + +@dataclass +class RandomPhaseShift(torch.nn.Module): + """Randomly shift the phase of all the frequencies in the signal.""" + + shift_limit: float = 0.25 + p: float = 0.5 + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply Random phase shift to each frequency of the fft of the input signal. + + :param x: Input features. (N,C,L)|(N,L)|(L) + :return: augmented features. (N,C,L)|(N,L)|(L) + """ + if torch.rand(1) < self.p: + # Take the rfft of the input tensor + x_freq = torch.fft.rfft(x, dim=-1) + # Create a random tensor of complex numbers each with a random phase but with magnitude of 1 + random_phase = torch.rand(*x_freq.shape, device=x.device, dtype=x.dtype) * 2 * np.pi * self.shift_limit + random_phase = torch.cos(random_phase) + 1j * torch.sin(random_phase) + # Multiply the rfft with the random phase + x_freq = x_freq * random_phase + # Take the irfft of the result + return torch.fft.irfft(x_freq, dim=-1) + return x + + +@dataclass +class Reverse1D(torch.nn.Module): + """Reverse augmentation for 1D signals.""" + + p: float = 0.5 + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the augmentation to the input signal. + + :param x: Input features. (N,C,L)|(N,L) + :return: Augmented features (N,C,L)|(N,L) + """ + augmented_x = x.clone() + for i in range(x.shape[0]): + if torch.rand(1) < self.p: + augmented_x[i] = torch.flip(x[i], [-1]) + return augmented_x + + +@dataclass +class SubstractChannels(torch.nn.Module): + """Randomly substract other channels from the current one.""" + + p: float = 0.5 + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply substracting other channels to the input signal. + + :param x: Input features. (N,C,L) + :return: Augmented features. (N,C,L) + """ + if x.shape[1] == 1: + raise ValueError( + "Sequence only has 1 channel. No channels to subtract from each other", + ) + if torch.rand(1) < self.p: + length = x.shape[1] - 1 + total = x.sum(dim=1) / length + x = x - total.unsqueeze(1) + (x / length) + return x diff --git a/epochalyst/pipeline/model/training/augmentation/utils.py b/epochalyst/pipeline/model/training/augmentation/utils.py new file mode 100644 index 0000000..af2c644 --- /dev/null +++ b/epochalyst/pipeline/model/training/augmentation/utils.py @@ -0,0 +1,116 @@ +"""Module providing utility classes for applying augmentations to data. + +Classes: +- CustomApplyOne: A custom sequential class for applying a single augmentation fro a selection based on their probabilities. +- CustomSequential: A custom sequential class for applying augmentations sequentially. +- NoOp: A class representing a no-operation augmentation. +""" + +from dataclasses import dataclass, field +from typing import Any + +import torch + + +@dataclass +class CustomApplyOne: + """Custom sequential class for augmentations.""" + + probabilities_tensor: torch.Tensor = field(init=False) + x_transforms: list[Any] = field(default_factory=list) + xy_transforms: list[Any] = field(default_factory=list) + + def __post_init__(self) -> None: + """Post initialization function of CustomApplyOne.""" + self.probabilities = [] + if self.x_transforms is not None: + for transform in self.x_transforms: + self.probabilities.append(transform.p) + if self.xy_transforms is not None: + for transform in self.xy_transforms: + self.probabilities.append(transform.p) + + # Make tensor from probs + self.probabilities_tensor = torch.tensor(self.probabilities) + # Ensure sum is 1 + self.probabilities_tensor /= self.probabilities_tensor.sum() + self.all_transforms = self.x_transforms + self.xy_transforms + + def __call__( + self, + x: torch.Tensor, + y: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Apply the augmentations sequentially. + + :param x: Input features + :param y: Input labels + :return: Augmented features and labels + """ + transform = self.all_transforms[ + int( + torch.multinomial( + self.probabilities_tensor, + 1, + replacement=False, + ).item(), + ) + ] + if transform in self.x_transforms: + x = transform(x) + if transform in self.xy_transforms: + x, y = transform(x, y) + return x, y + + +@dataclass +class CustomSequential: + """Custom sequential class for augmentations. + + This class applies augmentations sequentially without probabilities. + """ + + x_transforms: list[Any] = field(default_factory=list) + xy_transforms: list[Any] = field(default_factory=list) + + def __call__( + self, + x: torch.Tensor, + y: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Apply the augmentations sequentially. + + :param x: Input features. + :param y: input labels. + :return: Augmented features and labels. + """ + if self.x_transforms is not None: + for transform in self.x_transforms: + x = transform(x) + if self.xy_transforms is not None: + for transform in self.xy_transforms: + x, y = transform(x, y) + return x, y + + +@dataclass +class NoOp(torch.nn.Module): + """CutMix augmentation for 1D signals. + + This class represents a no-operation augmentation. + """ + + p: float = 0.5 + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + """Apply the augmentation to the input signal. + + Args: + ---- + x (torch.Tensor): The input signal tensor. + + Returns: + ------- + torch.Tensor: The augmented input signal tensor. + """ + return x diff --git a/epochalyst/pipeline/model/training/models/timm.py b/epochalyst/pipeline/model/training/models/timm.py index b2723f7..539367b 100644 --- a/epochalyst/pipeline/model/training/models/timm.py +++ b/epochalyst/pipeline/model/training/models/timm.py @@ -1,17 +1,10 @@ -"""Timm model for 2D spectrogram classification.""" +"""Timm model for 2D image classification.""" import torch from torch import nn class Timm(nn.Module): - """Timm model for 2D spectrogram classification.. - - Input: - X: (n_samples, n_channel, n_width, n_height) - Y: (n_samples) - - Output: - out: (n_samples) + """Timm model for 2D image classification. :param in_channels: Number of input channels :param out_channels: Number of output channels diff --git a/requirements.txt b/requirements.txt index 5e7d2f3..07f8b89 100644 --- a/requirements.txt +++ b/requirements.txt @@ -42,3 +42,4 @@ tqdm==4.66.2 typing_extensions==4.10.0 urllib3==2.2.1 zipp==3.17.0 +kornia==0.7.2 diff --git a/tests/pipeline/model/training/augmentation/test_image_augmentations.py b/tests/pipeline/model/training/augmentation/test_image_augmentations.py new file mode 100644 index 0000000..fb02bbb --- /dev/null +++ b/tests/pipeline/model/training/augmentation/test_image_augmentations.py @@ -0,0 +1,51 @@ +from epochalyst.pipeline.model.training.augmentation import image_augmentations +import torch + + +class TestImageAugmentations: + def test_cutmix(self): + # Create a CutMix instance + cutmix = image_augmentations.CutMix(p=1.0) + + # Create dummy input and labels + x = torch.cat( + [torch.ones(16, 1, 100, 100), torch.zeros(16, 1, 100, 100)], dim=0 + ) + # Multiclass labels + y = torch.cat([torch.ones(16, 2), torch.zeros(16, 2)], dim=0) + # Apply CutMix augmentation + augmented_x, augmented_y = cutmix(x, y) + + # Assert the output shapes are correct + assert augmented_x.shape == x.shape + assert augmented_y.shape == y.shape + + # Because the images are all ones and zeros the mean of the pixels should be equal to the labels after being transformed + assert torch.allclose(augmented_x.mean(dim=-1).mean(dim=-1), augmented_y) + + cutmix = image_augmentations.CutMix(p=0) + augmented_x, augmented_y = cutmix(x, y) + + assert torch.all(augmented_x == x) & torch.all(augmented_y == y) + + def test_mixup(self): + mixup = image_augmentations.MixUp(p=1.0) + # Create dummy input and labels + x = torch.cat( + [torch.ones(16, 1, 100, 100), torch.zeros(16, 1, 100, 100)], dim=0 + ) + # Multiclass labels + y = torch.cat([torch.ones(16, 2), torch.zeros(16, 2)], dim=0) + # Apply CutMix augmentation + augmented_x, augmented_y = mixup(x, y) + # Assert the output shapes are correct + assert augmented_x.shape == x.shape + assert augmented_y.shape == y.shape + + # Because the images are all ones and zeros the mean of the pixels should be equal to the labels after being transformed + assert torch.allclose(augmented_x.mean(dim=-1).mean(dim=-1), augmented_y) + + mixup = image_augmentations.MixUp(p=0) + augmented_x, augmented_y = mixup(x, y) + + assert torch.all(augmented_x == x) & torch.all(augmented_y == y) diff --git a/tests/pipeline/model/training/augmentation/test_time_series_augmentations.py b/tests/pipeline/model/training/augmentation/test_time_series_augmentations.py new file mode 100644 index 0000000..0c38494 --- /dev/null +++ b/tests/pipeline/model/training/augmentation/test_time_series_augmentations.py @@ -0,0 +1,166 @@ +import numpy as np +from epochalyst.pipeline.model.training.augmentation import time_series_augmentations +import torch + + +def set_torch_seed(seed: int = 42) -> None: + """Set torch seed for reproducibility. + + :param seed: seed to set + + :return: None + """ + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # When running on the CuDNN backend, two further options must be set + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +class TestTimeSeriesAugmentations: + def test_cutmix1d(self): + set_torch_seed(42) + cutmix1d = time_series_augmentations.CutMix1D(p=1.0) + # Create dummy input and labels + x = torch.cat([torch.ones(16, 1, 100), torch.zeros(16, 1, 100)], dim=0) + # Multiclass labels + y = torch.cat([torch.ones(16, 2), torch.zeros(16, 2)], dim=0) + # Apply CutMix augmentation + augmented_x, augmented_y = cutmix1d(x, y) + + # Assert the output shapes are correct + assert augmented_x.shape == x.shape + assert augmented_y.shape == y.shape + + # Because the images are all ones and zeros the mean of the pixels should be equal to the labels after being transformed + assert torch.allclose(augmented_x.mean(dim=-1), augmented_y) + + cutmix1d = time_series_augmentations.CutMix1D(p=0) + augmented_x, augmented_y = cutmix1d(x, y) + assert torch.all(augmented_x == x) & torch.all(augmented_y == y) + + def test_mixup1d(self): + set_torch_seed(42) + mixup1d = time_series_augmentations.MixUp1D(p=1.0) + # Create dummy input and labels + x = torch.cat([torch.ones(16, 1, 100), torch.zeros(16, 1, 100)], dim=0) + # Multiclass labels + y = torch.cat([torch.ones(16, 2), torch.zeros(16, 2)], dim=0) + # Apply CutMix augmentation + augmented_x, augmented_y = mixup1d(x, y) + + # Assert the output shapes are correct + assert augmented_x.shape == x.shape + assert augmented_y.shape == y.shape + + # Because the images are all ones and zeros the mean of the pixels should be equal to the labels after being transformed + assert torch.allclose(augmented_x.mean(dim=-1), augmented_y) + + mixup1d = time_series_augmentations.MixUp1D(p=0) + augmented_x, augmented_y = mixup1d(x, y) + assert torch.all(augmented_x == x) & torch.all(augmented_y == y) + + def test_mirror1d(self): + set_torch_seed(42) + mirror1d = time_series_augmentations.Mirror1D(p=1.0) + x = torch.cat([torch.ones(32, 1, 50), torch.zeros(32, 1, 50)], dim=-1) + + augmented_x = mirror1d(x) + + # Assert the output shape is correct + assert augmented_x.shape == x.shape + + # Assert x is mirrored + assert torch.allclose( + augmented_x, + torch.cat([torch.zeros(32, 1, 50), torch.ones(32, 1, 50)], dim=-1), + ) + + mirror1d = time_series_augmentations.Mirror1D(p=0) + augmented_x = mirror1d(x) + assert torch.all(augmented_x == x) + + def test_random_amplitude_shift(self): + set_torch_seed(42) + low = 0.5 + high = 1.5 + random_amplitude_shift = time_series_augmentations.RandomAmplitudeShift( + p=1.0, low=low, high=high + ) + # Sum of 2 signals with the 2nd one being half the frequency of the first one + x = torch.sin(torch.linspace(0, 2 * np.pi, 1000)) + torch.sin( + torch.linspace(0, np.pi, 1000) + ) + augmented_x = random_amplitude_shift(x) + + # Assert the output shape is correct + assert augmented_x.shape == x.shape + # Assert that the resulting signals amplitudes do not go over the bounds that have been set + assert torch.all( + torch.abs(torch.fft.rfft(x)) * low <= torch.abs(torch.fft.rfft(augmented_x)) + ) & torch.all( + torch.abs(torch.fft.rfft(augmented_x)) + <= torch.abs(torch.fft.rfft(x)) * high + ) + + random_amplitude_shift = time_series_augmentations.RandomAmplitudeShift(p=0) + augmented_x = random_amplitude_shift(x) + assert torch.all(augmented_x == x) + + def test_random_phase_shift(self): + set_torch_seed(42) + random_phase_shift = time_series_augmentations.RandomPhaseShift(p=1.0) + x = torch.sin(torch.linspace(0, 2 * np.pi, 1000)) + augmented_x = random_phase_shift(x) + + # Assert the output shape is correct + assert augmented_x.shape == x.shape + + # Assert x is not equal to augmented x + assert not torch.allclose(augmented_x, x) + # Aseert that the absolute value of the rfft is still the same. Very high atol beacuse sin function isn't precise with 1000 points + assert torch.allclose( + torch.abs(torch.fft.rfft(x, dim=-1)), + torch.abs(torch.fft.rfft(augmented_x, dim=-1)), + atol=0.05, + ) + # Assert that the mean is still around 0 and equal to the original mean + assert torch.isclose(augmented_x.mean(), x.mean()) + assert torch.isclose(augmented_x.mean(), torch.tensor([0]).float()) + + random_phase_shift = time_series_augmentations.RandomPhaseShift(p=0) + augmented_x = random_phase_shift(x) + assert torch.all(augmented_x == x) + + def test_reverse_1d(self): + set_torch_seed(42) + reverse1d = time_series_augmentations.Reverse1D(p=1.0) + x = torch.sin(torch.linspace(0, 2 * np.pi, 1000)).unsqueeze(0) + test_x = torch.sin(torch.linspace(np.pi, 3 * np.pi, 1000)).unsqueeze(0) + augmented_x = reverse1d(x) + + # Assert the output shape is correct + assert augmented_x.shape == x.shape + # Assert the reversed sine wave is equal to 180 degrees phase shifted version + assert torch.allclose(test_x, augmented_x, atol=0.0000005) + + reverse1d = time_series_augmentations.Reverse1D(p=0) + augmented_x = reverse1d(x) + assert torch.all(augmented_x == x) + + def test_subtract_channels(self): + set_torch_seed(42) + subtract_channels = time_series_augmentations.SubstractChannels(p=1.0) + # Only works for multi-channel sequences + x = torch.ones(32, 2, 100) + augmented_x = subtract_channels(x) + + # Assert the output shape is correct + assert augmented_x.shape == x.shape + + assert torch.allclose(torch.zeros(*augmented_x.shape), augmented_x) + + subtract_channels = time_series_augmentations.SubstractChannels(p=0) + augmented_x = subtract_channels(x) + assert torch.all(augmented_x == x) diff --git a/tests/pipeline/model/training/augmentation/test_utils.py b/tests/pipeline/model/training/augmentation/test_utils.py new file mode 100644 index 0000000..f792797 --- /dev/null +++ b/tests/pipeline/model/training/augmentation/test_utils.py @@ -0,0 +1,87 @@ +from epochalyst.pipeline.model.training.augmentation import utils +import torch + + +def set_torch_seed(seed: int = 42) -> None: + """Set torch seed for reproducibility. + + :param seed: seed to set + + :return: None + """ + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # When running on the CuDNN backend, two further options must be set + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +class TestUtils: + def test_no_op(self): + no_op = utils.NoOp() + x = torch.rand(4, 1, 100, 100) + augmented_x = no_op(x) + + assert torch.all(augmented_x == x) + + def test_custom_sequential(self): + class DummyXStep: + def __call__(self, x: torch.Tensor): + return x + 1 + + class DummyXYStep: + def __call__(self, x: torch.Tensor, y: torch.Tensor): + return x + 1, y + 1 + + step1 = DummyXStep() + step2 = DummyXYStep() + + sequential = utils.CustomSequential(x_transforms=[step1], xy_transforms=[step2]) + + x = torch.ones(32, 1, 100) + y = torch.zeros(32, 1) + augmented_x, augmented_y = sequential(x, y) + + assert torch.all(augmented_x == x + 2) + assert torch.all(augmented_y == y + 1) + + def test_custom_apply_one(self): + class DummyXStep: + def __init__(self, p): + self.p = p + + def __call__(self, x: torch.Tensor): + return x + 1 + + class DummyXYStep: + def __init__(self, p): + self.p = p + + def __call__(self, x: torch.Tensor, y: torch.Tensor): + return x, y + 1 + + set_torch_seed(42) + step1 = DummyXStep(p=0.33) + step2 = DummyXStep(p=0.33) + step3 = DummyXYStep(p=0.33) + + apply_one = utils.CustomApplyOne(x_transforms=[step1, step2]) + + x = torch.ones(32, 1, 1) + y = torch.zeros(32, 1) + augmented_x, augmented_y = apply_one(x, y) + + assert torch.all(augmented_x == x + 1) + + apply_one = utils.CustomApplyOne( + x_transforms=[step1, step2], xy_transforms=[step3] + ) + augmented_x = x + augmented_y = y + for _ in range(10000): + augmented_x, augmented_y = apply_one(augmented_x, augmented_y) + # Assert that the xy transform is applied roughly 1/3 of the time + assert torch.all(3300 <= augmented_y) & torch.all(augmented_y <= 3366) + # Assert that the x transform is applied roughly 2/3 of the time + assert torch.all(6633 <= augmented_x) & torch.all(augmented_x <= 6700)