Skip to content

Commit

Permalink
Merge pull request #148 from TeamEpochGithub/140-enh-add-dataloader-a…
Browse files Browse the repository at this point in the history
…rgs-to-init-of-main-trainer

Add dataloader args to init of main trainer
  • Loading branch information
tolgakopar authored May 16, 2024
2 parents 53ac91b + a571cdc commit ba84458
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*.yaml jmvanselm@gmail.com
*.yml jmvanselm@gmail.com
*.toml jmvanselm@gmail.com
6 changes: 6 additions & 0 deletions epochalyst/pipeline/model/training/torch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class TorchTrainer(TrainingBlock):
- `model_name` (str): Name of the model
- `n_folds` (float): Number of folds for cross validation (0 for train full)
- `fold` (int): Fold number
- `dataloader_args` (dict): Arguments for the dataloader
Methods
-------
Expand Down Expand Up @@ -142,6 +143,8 @@ def log_to_terminal(self, message: str) -> None:
_fold: int = field(default=-1, init=False, repr=False, compare=False)
n_folds: float = field(default=-1, init=True, repr=False, compare=False)

dataloader_args: dict[str, Any] = field(default_factory=dict, repr=False)

def __post_init__(self) -> None:
"""Post init method for the TorchTrainer class."""
# Make sure to_predict is either "test" or "all" or "none"
Expand Down Expand Up @@ -397,6 +400,7 @@ def predict_on_loader(
collate_fn=(
collate_fn if hasattr(loader.dataset, "__getitems__") else None # type: ignore[arg-type]
),
**self.dataloader_args,
)
with torch.no_grad(), tqdm(loader, unit="batch", disable=False) as tepoch:
for data in tepoch:
Expand Down Expand Up @@ -473,12 +477,14 @@ def create_dataloaders(
batch_size=self.batch_size,
shuffle=True,
collate_fn=(collate_fn if hasattr(train_dataset, "__getitems__") else None), # type: ignore[arg-type]
**self.dataloader_args,
)
test_loader = DataLoader(
test_dataset,
batch_size=self.batch_size,
shuffle=False,
collate_fn=(collate_fn if hasattr(test_dataset, "__getitems__") else None), # type: ignore[arg-type]
**self.dataloader_args,
)
return train_loader, test_loader

Expand Down

0 comments on commit ba84458

Please sign in to comment.