Merge pull request #214 from TeamEpochGithub/162-ref-shorten-import-path-lengths

162 | Shorten imports, reworked folders
Jeffrey-Lim authored Jun 28, 2024
2 parents 914b262 + 307c43d commit e8a1b9c
Showing 58 changed files with 150 additions and 107 deletions.
13 changes: 8 additions & 5 deletions epochalyst/__init__.py
@@ -1,6 +1,9 @@
-"""The epochalyst package.
+"""The epochalyst package."""

-It consists of the following modules:
-- `logging`: The logging module contains the classes and methods to log the pipeline.
-- `pipeline`: The pipeline module contains the classes and methods to create a pipeline for the model.
-"""
+from .ensemble import EnsemblePipeline
+from .model import ModelPipeline
+
+__all__ = [
+    "ModelPipeline",
+    "EnsemblePipeline",
+]
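Net effect of the hunk above: the package docstring is trimmed and the two pipeline classes are re-exported at the top level. A minimal sketch of the resulting import, using only names from the new __all__:

    # Both pipelines are now importable straight from the package root.
    from epochalyst import EnsemblePipeline, ModelPipeline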
1 change: 0 additions & 1 deletion epochalyst/_core/__init__.py

This file was deleted.

1 change: 0 additions & 1 deletion epochalyst/_core/_caching/__init__.py

This file was deleted.

1 change: 0 additions & 1 deletion epochalyst/_core/_pipeline/__init__.py

This file was deleted.

5 changes: 5 additions & 0 deletions epochalyst/caching/__init__.py
@@ -0,0 +1,5 @@
+"""Caching module for epochalyst."""
+
+from .cacher import CacheArgs, Cacher
+
+__all__ = ["Cacher", "CacheArgs"]
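The cacher now lives in a public package instead of the private _core tree. A sketch of the new import next to the old one (old path copied from removed lines elsewhere in this diff):

    # Old (pre-PR): from epochalyst._core._caching._cacher import CacheArgs, _Cacher
    from epochalyst.caching import CacheArgs, Cacher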
epochalyst/_core/_caching/_cacher.py → epochalyst/caching/cacher.py
@@ -1,10 +1,12 @@
"""The cacher module contains the Cacher class."""

import glob
import os
import pickle
import sys
from typing import Any, Literal, TypedDict

-from epochalyst.logging.logger import Logger
+from epochalyst.logging import Logger

try:
    import dask.array as da
@@ -27,7 +29,6 @@
except ImportError:
    """User doesn't require these packages"""

-
if sys.version_info < (3, 11):  # pragma: no cover (<py311)
    from typing_extensions import NotRequired
else:  # pragma: no cover (py311+)
@@ -76,7 +77,7 @@ class CacheArgs(TypedDict):
    store_args: NotRequired[dict[str, Any]]


-class _Cacher(Logger):
+class Cacher(Logger):
    """The cacher is a flexible class that allows for caching of any data.
    The cacher uses cache_args to determine if the data is already cached and if so, return the cached data.
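Dropping the leading underscore makes Cacher part of the public API, so downstream code can type-check against it the way the pipelines below do. A minimal sketch (the helper is hypothetical; only the isinstance pattern is taken from this diff):

    from epochalyst.caching import Cacher

    def supports_caching(step: object) -> bool:
        # Same check the pipelines in this PR use before caching a step's output.
        return isinstance(step, Cacher)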
2 changes: 1 addition & 1 deletion epochalyst/pipeline/ensemble.py → epochalyst/ensemble.py
@@ -4,7 +4,7 @@

from agogos.training import ParallelTrainingSystem

-from epochalyst._core._caching._cacher import CacheArgs
+from epochalyst.caching import CacheArgs


class EnsemblePipeline(ParallelTrainingSystem):
6 changes: 5 additions & 1 deletion epochalyst/logging/__init__.py
@@ -1 +1,5 @@
-"""Module for core logging functionality."""
+"""Logging module, contains Logger class for logging messages to console and file."""
+
+from .logger import Logger
+
+__all__ = ["Logger"]
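Logger is likewise re-exported at package level. A sketch of the mixin usage (log_to_terminal's signature is taken from a hunk header further down; everything else here is a hypothetical example):

    from epochalyst.logging import Logger

    class MyStep(Logger):  # hypothetical user class
        def run(self) -> None:
            self.log_to_terminal("running step")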
5 changes: 2 additions & 3 deletions epochalyst/logging/logger.py
@@ -1,9 +1,8 @@
-"""Logger base class."""
+"""Logger base class for logging methods."""

import logging
import os
-from collections.abc import Mapping
-from typing import Any
+from typing import Any, Mapping


class Logger:
4 changes: 2 additions & 2 deletions epochalyst/pipeline/model/model.py → epochalyst/model.py
@@ -1,10 +1,10 @@
-"""ModelPipeline connects multiple transforming and training systems for extended training functionality."""
+"""Model module. Contains the ModelPipeline class."""

from typing import Any

from agogos.training import Pipeline

-from epochalyst._core._caching._cacher import CacheArgs
+from epochalyst.caching import CacheArgs


class ModelPipeline(Pipeline):
1 change: 0 additions & 1 deletion epochalyst/pipeline/__init__.py

This file was deleted.

1 change: 0 additions & 1 deletion epochalyst/pipeline/model/__init__.py

This file was deleted.

1 change: 0 additions & 1 deletion epochalyst/pipeline/model/training/__init__.py

This file was deleted.

1 change: 0 additions & 1 deletion epochalyst/pipeline/model/training/augmentation/__init__.py

This file was deleted.

1 change: 0 additions & 1 deletion epochalyst/pipeline/model/training/models/__init__.py

This file was deleted.

1 change: 0 additions & 1 deletion epochalyst/pipeline/model/transformation/__init__.py

This file was deleted.

14 changes: 14 additions & 0 deletions epochalyst/training/__init__.py
@@ -0,0 +1,14 @@
+"""Module containing training functionality for the epochalyst package."""
+
+from .pretrain_block import PretrainBlock
+from .torch_trainer import TorchTrainer, TrainValidationDataset
+from .training import TrainingPipeline
+from .training_block import TrainingBlock
+
+__all__ = [
+    "PretrainBlock",
+    "TrainingBlock",
+    "TorchTrainer",
+    "TrainingPipeline",
+    "TrainValidationDataset",
+]
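All training building blocks are now one import away, which is the point of this PR. A sketch contrasting old and new paths (the old path is copied from removed lines below):

    # Old (pre-PR): from epochalyst.pipeline.model.training.training_block import TrainingBlock
    from epochalyst.training import TorchTrainer, TrainingBlock, TrainingPipeline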
26 changes: 26 additions & 0 deletions epochalyst/training/augmentation/__init__.py
@@ -0,0 +1,26 @@
+"""Module containing implementation for augmentations."""
+
+from epochalyst.training.augmentation.image_augmentations import CutMix, MixUp
+from epochalyst.training.augmentation.time_series_augmentations import (
+    AddBackgroundNoiseWrapper,
+    CutMix1D,
+    EnergyCutmix,
+    Mirror1D,
+    MixUp1D,
+    RandomAmplitudeShift,
+    RandomPhaseShift,
+    SubtractChannels,
+)
+
+__all__ = [
+    "CutMix",
+    "MixUp",
+    "CutMix1D",
+    "MixUp1D",
+    "Mirror1D",
+    "EnergyCutmix",
+    "RandomPhaseShift",
+    "RandomAmplitudeShift",
+    "SubtractChannels",
+    "AddBackgroundNoiseWrapper",
+]
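The augmentations get the same flattening. A sketch of the shortened import (the old epochalyst.pipeline.model.training.augmentation prefix appears in the removed lines below):

    from epochalyst.training.augmentation import CutMix1D, MixUp1D, SubtractChannels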
epochalyst/pipeline/model/training/augmentation/time_series_augmentations.py → epochalyst/training/augmentation/time_series_augmentations.py
@@ -6,7 +6,7 @@

import numpy as np
import torch

-from epochalyst.pipeline.model.training.augmentation.utils import get_audiomentations
+from .utils import get_audiomentations


@dataclass
@@ -187,7 +187,7 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor:


@dataclass
-class SubstractChannels(torch.nn.Module):
+class SubtractChannels(torch.nn.Module):
    """Randomly subtract other channels from the current one."""

    p: float = 0.5
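Besides the move, this hunk fixes the misspelled class name. A usage sketch, assuming a (batch, channels, time) tensor layout; only p and the __call__ signature are visible in this diff:

    import torch

    from epochalyst.training.augmentation import SubtractChannels

    aug = SubtractChannels(p=0.5)
    x = torch.randn(8, 4, 1000)  # hypothetical (batch, channels, time) input
    x_aug = aug(x)  # __call__(x: torch.Tensor) -> torch.Tensor, per the hunk header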
epochalyst/pipeline/model/training/augmentation/utils.py → epochalyst/training/augmentation/utils.py
@@ -12,7 +12,7 @@

import torch

-from epochalyst.pipeline.model.training.utils.recursive_repr import recursive_repr
+from epochalyst.training.utils.recursive_repr import recursive_repr


def get_audiomentations() -> ModuleType:
7 changes: 7 additions & 0 deletions epochalyst/training/models/__init__.py
@@ -0,0 +1,7 @@
+"""Module for reusable models or wrappers."""
+
+from .timm import Timm
+
+__all__ = [
+    "Timm",
+]
epochalyst/pipeline/model/training/models/timm.py → epochalyst/training/models/timm.py
File renamed without changes.
epochalyst/pipeline/model/training/pretrain_block.py → epochalyst/training/pretrain_block.py
@@ -6,7 +6,7 @@

from joblib import hash

-from epochalyst.pipeline.model.training.training_block import TrainingBlock
+from .training_block import TrainingBlock


@dataclass
epochalyst/pipeline/model/training/torch_trainer.py → epochalyst/training/torch_trainer.py
@@ -19,22 +19,21 @@
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm

-from epochalyst._core._pipeline._custom_data_parallel import _CustomDataParallel
-from epochalyst.pipeline.model.training.training_block import TrainingBlock
-from epochalyst.pipeline.model.training.utils import _get_onnxrt, _get_openvino
-from epochalyst.pipeline.model.training.utils.tensor_functions import batch_to_device
+from ._custom_data_parallel import _CustomDataParallel
+from .training_block import TrainingBlock
+from .utils import _get_onnxrt, _get_openvino, batch_to_device

T = TypeVar("T", bound=Dataset)  # type: ignore[type-arg]
T_co = TypeVar("T_co", covariant=True)


-def custom_collate(batch: tuple[Tensor, ...]) -> tuple[Tensor, ...]:
+def custom_collate(batch: list[Tensor]) -> tuple[Tensor, Tensor]:
    """Collate function for the dataloader.
    :param batch: The batch to collate.
    :return: Collated batch.
    """
-    X, y = batch
+    X, y = batch[0], batch[1]
    return X, y
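The new signature documents what the function relies on: the dataloader hands it a two-element list, X batch first, y batch second. A sketch of the contract (shapes hypothetical; the import path follows the rename in this PR):

    import torch

    from epochalyst.training.torch_trainer import custom_collate

    batch = [torch.randn(32, 10), torch.randint(0, 2, (32,))]  # hypothetical [X, y] pair
    X, y = custom_collate(batch)
    assert len(X) == len(y) == 32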


@@ -167,7 +166,7 @@ def log_to_terminal(self, message: str) -> None:
    epochs: Annotated[int, Gt(0)] = 10
    patience: Annotated[int, Gt(0)] = -1  # Early stopping
    batch_size: Annotated[int, Gt(0)] = 32
-    collate_fn: Callable[[tuple[Tensor, ...]], tuple[Tensor, ...]] = field(default=custom_collate, init=True, repr=False, compare=False)
+    collate_fn: Callable[[list[Tensor]], tuple[Tensor, Tensor]] = field(default=custom_collate, init=True, repr=False, compare=False)

    # Checkpointing
    checkpointing_enabled: bool = field(default=True, init=True, repr=False, compare=False)
@@ -379,9 +378,7 @@ def _predict_after_train(
                concat_dataset,
                batch_size=self.batch_size,
                shuffle=False,
-                collate_fn=(
-                    self.collate_fn if hasattr(concat_dataset, "__getitems__") else None  # type: ignore[arg-type]
-                ),
+                collate_fn=(self.collate_fn if hasattr(concat_dataset, "__getitems__") else None),
            )
            return self.predict_on_loader(pred_dataloader), y
        case "validation":
@@ -413,7 +410,7 @@ def custom_predict(self, x: Any, **pred_args: Any) -> npt.NDArray[np.float32]:
            pred_dataset,
            batch_size=curr_batch_size,
            shuffle=False,
-            collate_fn=(self.collate_fn if hasattr(pred_dataset, "__getitems__") else None),  # type: ignore[arg-type]
+            collate_fn=(self.collate_fn if hasattr(pred_dataset, "__getitems__") else None),
        )

        # Predict with a single model
@@ -459,9 +456,7 @@ def predict_on_loader(
            loader.dataset,
            batch_size=loader.batch_size,
            shuffle=False,
-            collate_fn=(
-                self.collate_fn if hasattr(loader.dataset, "__getitems__") else None  # type: ignore[arg-type]
-            ),
+            collate_fn=(self.collate_fn if hasattr(loader.dataset, "__getitems__") else None),
            **self.dataloader_args,
        )
        if compile_method is None:
@@ -572,14 +567,14 @@ def create_dataloaders(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
-            collate_fn=(self.collate_fn if hasattr(train_dataset, "__getitems__") else None),  # type: ignore[arg-type]
+            collate_fn=(self.collate_fn if hasattr(train_dataset, "__getitems__") else None),
            **self.dataloader_args,
        )
        validation_loader = DataLoader(
            validation_dataset,
            batch_size=self.batch_size,
            shuffle=False,
-            collate_fn=(self.collate_fn if hasattr(validation_dataset, "__getitems__") else None),  # type: ignore[arg-type]
+            collate_fn=(self.collate_fn if hasattr(validation_dataset, "__getitems__") else None),
            **self.dataloader_args,
        )
        return train_loader, validation_loader
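The guard repeated across the hunks above switches the custom collate on only for datasets implementing PyTorch's batched-fetch hook __getitems__, which is what yields a ready [X, y] pair; anything else falls back to the default collate. A condensed sketch of the pattern (make_loader is a hypothetical helper, not part of the PR):

    from torch.utils.data import DataLoader

    def make_loader(dataset, collate_fn, batch_size=32):
        # Custom collate only applies to datasets with the batched __getitems__ hook.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=collate_fn if hasattr(dataset, "__getitems__") else None,
        )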
epochalyst/pipeline/model/training/training.py → epochalyst/training/training.py
@@ -4,10 +4,10 @@

from agogos.training import TrainingSystem, TrainType

-from epochalyst._core._caching._cacher import CacheArgs, _Cacher
+from epochalyst.caching import CacheArgs, Cacher


-class TrainingPipeline(TrainingSystem, _Cacher):
+class TrainingPipeline(TrainingSystem, Cacher):
    """The training pipeline. This is the class used to create the pipeline for the training of the model.
    :param steps: The steps to train the model.
@@ -35,9 +35,9 @@ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_arg

        # Furthest step
        for i, step in enumerate(self.get_steps()):
-            # Check if step is instance of _Cacher and if cache_args exists
-            if not isinstance(step, _Cacher) or not isinstance(step, TrainType):
-                self.log_to_debug(f"{step} is not instance of _Cacher or TrainType")
+            # Check if step is instance of Cacher and if cache_args exists
+            if not isinstance(step, Cacher) or not isinstance(step, TrainType):
+                self.log_to_debug(f"{step} is not instance of Cacher or TrainType")
                continue

            step_args = train_args.get(step.__class__.__name__, None)
@@ -89,9 +89,9 @@ def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any)

        # Retrieve furthest step calculated
        for i, step in enumerate(self.get_steps()):
-            # Check if step is instance of _Cacher and if cache_args exists
-            if not isinstance(step, _Cacher) or not isinstance(step, TrainType):
-                self.log_to_debug(f"{step} is not instance of _Cacher or TrainType")
+            # Check if step is instance of Cacher and if cache_args exists
+            if not isinstance(step, Cacher) or not isinstance(step, TrainType):
+                self.log_to_debug(f"{step} is not instance of Cacher or TrainType")
                continue

            step_args = pred_args.get(step.__class__.__name__, None)
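Both loops key per-step kwargs by the step's class name. A sketch of the calling convention this implies (the block name and its arguments are hypothetical; only the lookup by __class__.__name__ comes from the diff):

    pipeline = TrainingPipeline(steps=[...])  # steps param per the docstring above
    pipeline.train(
        x_train, y_train,  # hypothetical data
        cache_args=None,
        MyTrainingBlock={"some_arg": 5},  # fetched via train_args.get(step.__class__.__name__)
    )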
epochalyst/pipeline/model/training/training_block.py → epochalyst/training/training_block.py
@@ -5,10 +5,10 @@

from agogos.training import Trainer

-from epochalyst._core._caching._cacher import CacheArgs, _Cacher
+from epochalyst.caching import CacheArgs, Cacher


-class TrainingBlock(Trainer, _Cacher):
+class TrainingBlock(Trainer, Cacher):
    """The training block is a flexible block that allows for training of any model.
    Methods
epochalyst/pipeline/model/training/utils/__init__.py → epochalyst/training/utils/__init__.py
@@ -1,8 +1,12 @@
"""Module with utility functions for training."""

from .get_dependencies import _get_onnxrt, _get_openvino
from .recursive_repr import recursive_repr
from .tensor_functions import batch_to_device

__all__ = [
    "_get_onnxrt",
    "_get_openvino",
    "batch_to_device",
    "recursive_repr",
]
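The training utilities are likewise reachable from one place. A sketch limited to the names re-exported above (their signatures are not shown in this diff):

    from epochalyst.training.utils import batch_to_device, recursive_repr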
File renamed without changes.
9 changes: 9 additions & 0 deletions epochalyst/transformation/__init__.py
@@ -0,0 +1,9 @@
+"""Module containing transformation functions for the model pipeline."""
+
+from .transformation import TransformationPipeline
+from .transformation_block import TransformationBlock
+
+__all__ = [
+    "TransformationPipeline",
+    "TransformationBlock",
+]
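Transformation mirrors training. A sketch of the shortened import (the old epochalyst.pipeline.model.transformation prefix is implied by the deleted __init__.py above):

    from epochalyst.transformation import TransformationBlock, TransformationPipeline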
epochalyst/pipeline/model/transformation/transformation.py → epochalyst/transformation/transformation.py
@@ -1,15 +1,15 @@
-"""TransformationPipeline that extends from TransformingSystem, _Cacher and _Logger."""
+"""TransformationPipeline that extends from TransformingSystem, Cacher and _Logger."""

from dataclasses import dataclass
from typing import Any

from agogos.transforming import TransformingSystem, TransformType

-from epochalyst._core._caching._cacher import CacheArgs, _Cacher
+from epochalyst.caching.cacher import CacheArgs, Cacher


@dataclass
-class TransformationPipeline(TransformingSystem, _Cacher):
+class TransformationPipeline(TransformingSystem, Cacher):
    """TransformationPipeline is the class used to create the pipeline for the transformation of the data.
    ### Parameters:
@@ -80,9 +80,9 @@ def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_

        # Furthest step
        for i, step in enumerate(self.get_steps()):
-            # Check if step is instance of _Cacher and if cache_args exists
-            if not isinstance(step, _Cacher) or not isinstance(step, TransformType):
-                self.log_to_debug(f"{step} is not instance of _Cacher or TransformType")
+            # Check if step is instance of Cacher and if cache_args exists
+            if not isinstance(step, Cacher) or not isinstance(step, TransformType):
+                self.log_to_debug(f"{step} is not instance of Cacher or TransformType")
                continue

            step_args = transform_args.get(step.__class__.__name__, None)