Skip to content

Commit

Permalink
Add **dataloader_args to dataloaders in torch trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
schobbejak committed May 16, 2024
2 parents 06f123c + 53ac91b commit eb1c207
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
5 changes: 5 additions & 0 deletions epochalyst/pipeline/model/training/torch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ def log_to_terminal(self, message: str) -> None:
_fold: int = field(default=-1, init=False, repr=False, compare=False)
n_folds: float = field(default=-1, init=True, repr=False, compare=False)

dataloader_args: dict[str, Any] = field(default_factory=dict, repr=False)

def __post_init__(self) -> None:
"""Post init method for the TorchTrainer class."""
# Make sure to_predict is either "test" or "all" or "none"
Expand Down Expand Up @@ -397,6 +399,7 @@ def predict_on_loader(
collate_fn=(
collate_fn if hasattr(loader.dataset, "__getitems__") else None # type: ignore[arg-type]
),
**self.dataloader_args,
)
with torch.no_grad(), tqdm(loader, unit="batch", disable=False) as tepoch:
for data in tepoch:
Expand Down Expand Up @@ -473,12 +476,14 @@ def create_dataloaders(
batch_size=self.batch_size,
shuffle=True,
collate_fn=(collate_fn if hasattr(train_dataset, "__getitems__") else None), # type: ignore[arg-type]
**self.dataloader_args,
)
test_loader = DataLoader(
test_dataset,
batch_size=self.batch_size,
shuffle=False,
collate_fn=(collate_fn if hasattr(test_dataset, "__getitems__") else None), # type: ignore[arg-type]
**self.dataloader_args,
)
return train_loader, test_loader

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "epochalyst"
version = "0.3.2"
version = "0.3.3"
authors = [
{ name = "Jasper van Selm", email = "jmvanselm@gmail.com" },
{ name = "Ariel Ebersberger", email = "arielebersberger@gmail.com" },
Expand Down

0 comments on commit eb1c207

Please sign in to comment.