Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/v0.3' into 112-ref-add-argument-…
Browse files Browse the repository at this point in the history
…in-torch-trainer-whether-to-predict-or-not
  • Loading branch information
hjdeheer committed Apr 16, 2024
2 parents e01e44f + 9f147dd commit 84229b0
Show file tree
Hide file tree
Showing 11 changed files with 745 additions and 9 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,5 @@ repos:
- torch
- traitlets
- timm
- kornia
args: [ --disallow-any-generics, --disallow-untyped-defs, --disable-error-code=import-untyped]
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ For caching some imports are only required, these have to be manually installed
- pyarrow >= 6.0.0 (Read parquet files)
- annotated-types >= 0.6.0

### Model

There is support for using timm models. To be able to do so the user must manually install timm.
- timm >= 0.9.16

### Augmentation

There is also implementations of augmentations that are not in commonly used packages. Most of these are for time series data but there are implmenetations for CutMix and MixUp for images that can be used in the pipeline. To be able to use these the user must manually install kornia.

- kornia >= 0.7.2

## Documentation

Documentation is generated using [Sphinx](https://www.sphinx-doc.org/en/master/).
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Module containing implementation for augmentations."""
103 changes: 103 additions & 0 deletions epochalyst/pipeline/model/training/augmentation/image_augmentations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Contains implementation of several image augmentations using PyTorch."""

from dataclasses import dataclass, field
from typing import Any

import torch


def get_kornia_mix() -> Any: # noqa: ANN401
"""Return kornia mix."""
try:
import kornia

except ImportError:
raise ImportError(
"If you want to use this augmentation you must install kornia",
) from None

else:
return kornia.augmentation._2d.mix # noqa: SLF001


@dataclass
class CutMix:
"""2D CutMix implementation for spectrogram data augmentation.
:param cut_size: The size of the cut
:param same_on_batch: Apply the same transformation across the batch
:param p: The probability of applying the filter
"""

cut_size: tuple[float, float] = field(default=(0.0, 1.0))
same_on_batch: bool = False
p: float = 0.5

def __post_init__(self) -> None:
"""Check if the filter type is valid."""
self.cutmix = get_kornia_mix().cutmix.RandomCutMixV2(
p=self.p,
cut_size=self.cut_size,
same_on_batch=self.same_on_batch,
data_keys=["input", "class"],
)

def __call__(
self,
x: torch.Tensor,
y: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Randomly patch the input with another sample.
:param x: Input images. (N,C,W,H)
:param y: Input labels. (N,C)
"""
dummy_labels = torch.arange(x.size(0))
augmented_x, augmentation_info = self.cutmix(x, dummy_labels)
augmentation_info = augmentation_info[0]

y = y.float()
y_result = y.clone()
for i in range(augmentation_info.shape[0]):
y_result[i] = y[i] * (1 - augmentation_info[i, 2]) + y[int(augmentation_info[i, 1])] * augmentation_info[i, 2]

return augmented_x, y_result


@dataclass
class MixUp:
"""2D MixUp implementation for spectrogram data augmentation.
:param lambda_val: The range of the mixup coefficient
:param same_on_batch: Apply the same transformation across the batch
:param p: The probability of applying the filter
"""

lambda_val: tuple[float, float] = field(default=(0.0, 1.0))
same_on_batch: bool = False
p: float = 0.5

def __post_init__(self) -> None:
"""Check if the filter type is valid."""
self.mixup = get_kornia_mix().mixup.RandomMixUpV2(
p=self.p,
lambda_val=self.lambda_val,
same_on_batch=self.same_on_batch,
data_keys=["input", "class"],
)

def __call__(
self,
x: torch.Tensor,
y: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Randomly patch the input with another sample."""
dummy_labels = torch.arange(x.size(0))
augmented_x, augmentation_info = self.mixup(x, dummy_labels)

y = y.float()
y_result = y.clone()
for i in range(augmentation_info.shape[0]):
y_result[i] = y[i] * (1 - augmentation_info[i, 2]) + y[int(augmentation_info[i, 1])] * augmentation_info[i, 2]

return augmented_x, y_result
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
"""Contains implementation of several custom time series augmentations using PyTorch."""

from dataclasses import dataclass

import numpy as np
import torch


@dataclass
class CutMix1D(torch.nn.Module):
"""CutMix augmentation for 1D signals.
Randomly select a percentage between 'low' and 'high' to preserve on the left side of the signal.
The right side will be replaced by the corresponding range from another sample from the batch.
The labels become the weighted average of the mixed signals where weights are the mix ratios.
"""

p: float = 0.5
low: float = 0
high: float = 1

def __call__(
self,
x: torch.Tensor,
y: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Appply CutMix to the batch of 1D signal.
:param x: Input features. (N,C,L)
:param y: Input labels. (N,C)
:return: The augmented features and labels
"""
indices = torch.arange(x.shape[0], device=x.device, dtype=torch.int)
shuffled_indices = torch.randperm(indices.shape[0])

low_len = int(self.low * x.shape[-1])
high_len = int(self.high * x.shape[-1])
cutoff_indices = torch.randint(
low_len,
high_len,
(x.shape[-1],),
device=x.device,
dtype=torch.int,
)
cutoff_rates = cutoff_indices.float() / x.shape[-1]

augmented_x = x.clone()
augmented_y = y.clone().float()
for i in range(x.shape[0]):
if torch.rand(1) < self.p:
augmented_x[i, :, cutoff_indices[i] :] = x[
shuffled_indices[i],
:,
cutoff_indices[i] :,
]
augmented_y[i] = y[i] * cutoff_rates[i] + y[shuffled_indices[i]] * (1 - cutoff_rates[i])
return augmented_x, augmented_y


@dataclass
class MixUp1D(torch.nn.Module):
"""MixUp augmentation for 1D signals.
Randomly takes the weighted average of 2 samples and their labels with random weights.
"""

p: float = 0.5

def __call__(
self,
x: torch.Tensor,
y: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Appply MixUp to the batch of 1D signal.
:param x: Input features. (N,C,L)|(N,L)
:param y: Input labels. (N,C)
:return: The augmented features and labels
"""
indices = torch.arange(x.shape[0], device=x.device, dtype=torch.int)
shuffled_indices = torch.randperm(indices.shape[0])

augmented_x = x.clone()
augmented_y = y.clone().float()
for i in range(x.shape[0]):
if torch.rand(1) < self.p:
lambda_ = torch.rand(1, device=x.device)
augmented_x[i] = lambda_ * x[i] + (1 - lambda_) * x[shuffled_indices[i]]
augmented_y[i] = lambda_ * y[i] + (1 - lambda_) * y[shuffled_indices[i]]
return augmented_x, augmented_y


@dataclass
class Mirror1D(torch.nn.Module):
"""Mirror augmentation for 1D signals.
Mirrors the signal around its mean in the horizontal(time) axis.
"""

p: float = 0.5

def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""Apply the augmentation to the input signal.
:param x: Input features. (N,C,L)|(N,L)
:return: Augmented features. (N,C,L)|(N,L)
"""
augmented_x = x.clone()
for i in range(x.shape[0]):
if torch.rand(1) < self.p:
augmented_x[i] = -1 * x[i] + 2 * x[i].mean(dim=-1).unsqueeze(-1)
return augmented_x


@dataclass
class RandomAmplitudeShift(torch.nn.Module):
"""Randomly scale the amplitude of all the frequencies in the signal."""

low: float = 0.5
high: float = 1.5
p: float = 0.5

def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""Apply the augmentation to the input signal.
:param x: Input features. (N,C,L)|(N,L)
:return: Augmented features. (N,C,L)|(N,L)
"""
if torch.rand(1) < self.p:
# Take the rfft of the input tensor
x_freq = torch.fft.rfft(x, dim=-1)
# Create a random tensor of scaler in the range [low,high]
random_amplitude = torch.rand(*x_freq.shape, device=x.device, dtype=x.dtype) * (self.high - self.low) + self.low
# Multiply the rfft with the random amplitude
x_freq = x_freq * random_amplitude
# Take the irfft of the result
return torch.fft.irfft(x_freq, dim=-1)
return x


@dataclass
class RandomPhaseShift(torch.nn.Module):
"""Randomly shift the phase of all the frequencies in the signal."""

shift_limit: float = 0.25
p: float = 0.5

def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Random phase shift to each frequency of the fft of the input signal.
:param x: Input features. (N,C,L)|(N,L)|(L)
:return: augmented features. (N,C,L)|(N,L)|(L)
"""
if torch.rand(1) < self.p:
# Take the rfft of the input tensor
x_freq = torch.fft.rfft(x, dim=-1)
# Create a random tensor of complex numbers each with a random phase but with magnitude of 1
random_phase = torch.rand(*x_freq.shape, device=x.device, dtype=x.dtype) * 2 * np.pi * self.shift_limit
random_phase = torch.cos(random_phase) + 1j * torch.sin(random_phase)
# Multiply the rfft with the random phase
x_freq = x_freq * random_phase
# Take the irfft of the result
return torch.fft.irfft(x_freq, dim=-1)
return x


@dataclass
class Reverse1D(torch.nn.Module):
"""Reverse augmentation for 1D signals."""

p: float = 0.5

def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""Apply the augmentation to the input signal.
:param x: Input features. (N,C,L)|(N,L)
:return: Augmented features (N,C,L)|(N,L)
"""
augmented_x = x.clone()
for i in range(x.shape[0]):
if torch.rand(1) < self.p:
augmented_x[i] = torch.flip(x[i], [-1])
return augmented_x


@dataclass
class SubstractChannels(torch.nn.Module):
"""Randomly substract other channels from the current one."""

p: float = 0.5

def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""Apply substracting other channels to the input signal.
:param x: Input features. (N,C,L)
:return: Augmented features. (N,C,L)
"""
if x.shape[1] == 1:
raise ValueError(
"Sequence only has 1 channel. No channels to subtract from each other",
)
if torch.rand(1) < self.p:
length = x.shape[1] - 1
total = x.sum(dim=1) / length
x = x - total.unsqueeze(1) + (x / length)
return x
Loading

0 comments on commit 84229b0

Please sign in to comment.