Merge pull request #152 from TeamEpochGithub/v0.3.3
V0.3.3
justanotherariel authored May 16, 2024
2 parents 06f123c + 6d3f817 commit 085e8d6
Showing 8 changed files with 88 additions and 23 deletions.
27 changes: 21 additions & 6 deletions epochalyst/pipeline/model/training/torch_trainer.py
@@ -20,6 +20,7 @@
from epochalyst._core._pipeline._custom_data_parallel import _CustomDataParallel
from epochalyst.logging.section_separator import print_section_separator
from epochalyst.pipeline.model.training.training_block import TrainingBlock
from epochalyst.pipeline.model.training.utils.tensor_functions import batch_to_device

T = TypeVar("T", bound=Dataset) # type: ignore[type-arg]
T_co = TypeVar("T_co", covariant=True)
@@ -43,6 +44,9 @@ class TorchTrainer(TrainingBlock):
- `model_name` (str): Name of the model
- `n_folds` (float): Number of folds for cross validation (0 for train full)
- `fold` (int): Fold number
- `dataloader_args` (dict): Arguments for the dataloader
- `x_tensor_type` (str): Type of x tensor for data
- `y_tensor_type` (str): Type of y tensor for labels
Methods
-------
@@ -142,6 +146,12 @@ def log_to_terminal(self, message: str) -> None:
_fold: int = field(default=-1, init=False, repr=False, compare=False)
n_folds: float = field(default=-1, init=True, repr=False, compare=False)

dataloader_args: dict[str, Any] = field(default_factory=dict, repr=False)

# Types for tensors
x_tensor_type: str = "float"
y_tensor_type: str = "float"

def __post_init__(self) -> None:
"""Post init method for the TorchTrainer class."""
# Make sure to_predict is either "test" or "all" or "none"
@@ -397,10 +407,11 @@ def predict_on_loader(
collate_fn=(
collate_fn if hasattr(loader.dataset, "__getitems__") else None # type: ignore[arg-type]
),
**self.dataloader_args,
)
with torch.no_grad(), tqdm(loader, unit="batch", disable=False) as tepoch:
for data in tepoch:
X_batch = data[0].to(self.device).float()
X_batch = batch_to_device(data[0], self.x_tensor_type, self.device)

y_pred = self.model(X_batch).squeeze(1).cpu().numpy()
predictions.extend(y_pred)
@@ -473,12 +484,14 @@ def create_dataloaders(
batch_size=self.batch_size,
shuffle=True,
collate_fn=(collate_fn if hasattr(train_dataset, "__getitems__") else None), # type: ignore[arg-type]
**self.dataloader_args,
)
test_loader = DataLoader(
test_dataset,
batch_size=self.batch_size,
shuffle=False,
collate_fn=(collate_fn if hasattr(test_dataset, "__getitems__") else None), # type: ignore[arg-type]
**self.dataloader_args,
)
return train_loader, test_loader

@@ -601,8 +614,9 @@ def _train_one_epoch(
)
for batch in pbar:
X_batch, y_batch = batch
X_batch = X_batch.to(self.device).float()
y_batch = y_batch.to(self.device).float()

X_batch = batch_to_device(X_batch, self.x_tensor_type, self.device)
y_batch = batch_to_device(y_batch, self.y_tensor_type, self.device)

# Forward pass
y_pred = self.model(X_batch).squeeze(1)
@@ -619,7 +633,7 @@

# Step the scheduler
if self.initialized_scheduler is not None:
self.initialized_scheduler.step(epoch=epoch)
self.initialized_scheduler.step(epoch=epoch + 1)

# Remove the cuda cache
torch.cuda.empty_cache()
@@ -644,8 +658,9 @@ def _val_one_epoch(
with torch.no_grad():
for batch in pbar:
X_batch, y_batch = batch
X_batch = X_batch.to(self.device).float()
y_batch = y_batch.to(self.device).float()

X_batch = batch_to_device(X_batch, self.x_tensor_type, self.device)
y_batch = batch_to_device(y_batch, self.y_tensor_type, self.device)

# Forward pass
y_pred = self.model(X_batch).squeeze(1)
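
For context on the new dataloader_args field above: anything placed in that dict is forwarded verbatim to torch.utils.data.DataLoader in both predict_on_loader and create_dataloaders. A minimal sketch of that forwarding, with hypothetical argument values that are not part of this commit:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical values; any keyword accepted by torch.utils.data.DataLoader can go here.
dataloader_args = {"num_workers": 4, "pin_memory": True, "drop_last": True}

dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))

# Mirrors what create_dataloaders now does with **self.dataloader_args.
train_loader = DataLoader(dataset, batch_size=16, shuffle=True, **dataloader_args)
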
5 changes: 4 additions & 1 deletion epochalyst/pipeline/model/training/training.py
@@ -63,6 +63,7 @@ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_arg
x, y = super().train(x, y, **train_args)

if cache_args:
self.log_to_terminal(f"Storing cache for x and y to {cache_args['storage_path']}")
self._store_cache(name=self.get_hash() + "x", data=x, cache_args=cache_args)
self._store_cache(name=self.get_hash() + "y", data=y, cache_args=cache_args)

@@ -115,7 +116,9 @@ def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any)

x = super().predict(x, **pred_args)

self._store_cache(self.get_hash() + "p", x, cache_args) if cache_args else None
if cache_args:
self.log_to_terminal(f"Storing cache for x to {cache_args['storage_path']}")
self._store_cache(self.get_hash() + "p", x, cache_args)

# Set steps to original in case class is called again
self.steps = self.all_steps
34 changes: 19 additions & 15 deletions epochalyst/pipeline/model/training/training_block.py
@@ -76,16 +76,18 @@ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_arg

x, y = self.custom_train(x, y, **train_args)

self._store_cache(
name=self.get_hash() + "x",
data=x,
cache_args=cache_args,
) if cache_args else None
self._store_cache(
name=self.get_hash() + "y",
data=y,
cache_args=cache_args,
) if cache_args else None
if cache_args:
self.log_to_terminal(f"Storing cache for x and y to {cache_args['storage_path']}")
self._store_cache(
name=self.get_hash() + "x",
data=x,
cache_args=cache_args,
)
self._store_cache(
name=self.get_hash() + "y",
data=y,
cache_args=cache_args,
)

return x, y

@@ -116,11 +118,13 @@ def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any)

x = self.custom_predict(x, **pred_args)

self._store_cache(
name=self.get_hash() + "p",
data=x,
cache_args=cache_args,
) if cache_args else None
if cache_args:
self.log_to_terminal(f"Store cache for predictions to {cache_args['storage_path']}")
self._store_cache(
name=self.get_hash() + "p",
data=x,
cache_args=cache_args,
)

return x

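
The guarded-store pattern introduced above (an explicit if cache_args: block with a log line, replacing the old "... if cache_args else None" expression statements) is repeated in training.py and training_block.py here and in the transformation modules further down. A standalone sketch of the same idea, where maybe_cache is a hypothetical stand-in for the pipeline's _store_cache and the only assumption about cache_args is that it carries a 'storage_path' entry:

import pickle
from pathlib import Path
from typing import Any


def maybe_cache(name: str, data: Any, cache_args: dict[str, Any] | None) -> None:
    """Store data only when caching is configured, and log where it goes."""
    if not cache_args:
        return  # caching disabled: skip silently, as the pipeline blocks do
    path = Path(cache_args["storage_path"]) / f"{name}.pkl"
    print(f"Storing cache for {name} to {cache_args['storage_path']}")
    with path.open("wb") as f:
        pickle.dump(data, f)
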
1 change: 1 addition & 0 deletions epochalyst/pipeline/model/training/utils/__init__.py
@@ -0,0 +1 @@
"""Module with utility functions for training."""
40 changes: 40 additions & 0 deletions epochalyst/pipeline/model/training/utils/tensor_functions.py
@@ -0,0 +1,40 @@
"""Module with tensor functions."""
import torch
from torch import Tensor


def batch_to_device(batch: Tensor, tensor_type: str, device: torch.device) -> Tensor:
"""Move batch to device with certain type.
:param batch: Batch to move
:param tensor_type: Type of the batch
:param device: Device to move the batch to
:return: The moved tensor
"""
type_conversion = {
"float": torch.float32,
"float32": torch.float32,
"float64": torch.float64,
"double": torch.float64,
"float16": torch.float16,
"half": torch.float16,
"int": torch.int32,
"int32": torch.int32,
"int64": torch.int64,
"long": torch.int64,
"int16": torch.int16,
"short": torch.int16,
"uint8": torch.uint8,
"byte": torch.uint8,
"int8": torch.int8,
"bfloat16": torch.bfloat16,
"bool": torch.bool,
}

if tensor_type in type_conversion:
dtype = type_conversion[tensor_type]
batch = batch.to(device, dtype=dtype)
else:
raise ValueError(f"Unsupported tensor type: {tensor_type}")

return batch
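
A short usage sketch for the new helper; the tensor shapes and the "half"/"long" type strings below are illustrative choices, not taken from the commit:

import torch

from epochalyst.pipeline.model.training.utils.tensor_functions import batch_to_device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

features = torch.randn(8, 3, 32, 32)   # float32 image-like batch
labels = torch.randint(0, 10, (8,))    # int64 class indices

x = batch_to_device(features, "half", device)  # cast to torch.float16 and moved
y = batch_to_device(labels, "long", device)    # already int64, so only moved to the device

print(x.dtype, y.dtype)  # torch.float16 torch.int64
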
1 change: 1 addition & 0 deletions epochalyst/pipeline/model/transformation/transformation.py
@@ -105,6 +105,7 @@ def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_
data = super().transform(data, **transform_args)

if cache_args:
self.log_to_terminal(f"Storing cache for pipeline to {cache_args['storage_path']}")
self._store_cache(self.get_hash(), data, cache_args)

# Set steps to original in case class is called again
1 change: 1 addition & 0 deletions epochalyst/pipeline/model/transformation/transformation_block.py
@@ -73,6 +73,7 @@ def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_

data = self.custom_transform(data, **transform_args)
if cache_args:
self.log_to_terminal(f"Storing cache to {cache_args['storage_path']}")
self._store_cache(name=self.get_hash(), data=data, cache_args=cache_args)
return data

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "epochalyst"
version = "0.3.2"
version = "0.3.3"
authors = [
{ name = "Jasper van Selm", email = "jmvanselm@gmail.com" },
{ name = "Ariel Ebersberger", email = "arielebersberger@gmail.com" },
