diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
new file mode 100644
index 0000000..1ec9ffd
--- /dev/null
+++ b/.github/CODEOWNERS
@@ -0,0 +1,3 @@
+*.yaml jmvanselm@gmail.com
+*.yml jmvanselm@gmail.com
+*.toml jmvanselm@gmail.com
diff --git a/epochalyst/pipeline/model/training/torch_trainer.py b/epochalyst/pipeline/model/training/torch_trainer.py
index 99e478f..adeb326 100644
--- a/epochalyst/pipeline/model/training/torch_trainer.py
+++ b/epochalyst/pipeline/model/training/torch_trainer.py
@@ -43,6 +43,7 @@ class TorchTrainer(TrainingBlock):
     - `model_name` (str): Name of the model
     - `n_folds` (float): Number of folds for cross validation (0 for train full,
     - `fold` (int): Fold number
+    - `dataloader_args` (dict): Arguments for the dataloader
 
     Methods
     -------
@@ -142,6 +143,8 @@ def log_to_terminal(self, message: str) -> None:
     _fold: int = field(default=-1, init=False, repr=False, compare=False)
     n_folds: float = field(default=-1, init=True, repr=False, compare=False)
 
+    dataloader_args: dict[str, Any] = field(default_factory=dict, repr=False)
+
     def __post_init__(self) -> None:
         """Post init method for the TorchTrainer class."""
         # Make sure to_predict is either "test" or "all" or "none"
@@ -397,6 +400,7 @@ def predict_on_loader(
             collate_fn=(
                 collate_fn if hasattr(loader.dataset, "__getitems__") else None  # type: ignore[arg-type]
             ),
+            **self.dataloader_args,
         )
         with torch.no_grad(), tqdm(loader, unit="batch", disable=False) as tepoch:
             for data in tepoch:
@@ -473,12 +477,14 @@ def create_dataloaders(
             batch_size=self.batch_size,
             shuffle=True,
             collate_fn=(collate_fn if hasattr(train_dataset, "__getitems__") else None),  # type: ignore[arg-type]
+            **self.dataloader_args,
         )
         test_loader = DataLoader(
             test_dataset,
             batch_size=self.batch_size,
             shuffle=False,
             collate_fn=(collate_fn if hasattr(test_dataset, "__getitems__") else None),  # type: ignore[arg-type]
+            **self.dataloader_args,
         )
 
         return train_loader, test_loader
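Not part of the patch, but a minimal usage sketch of the mechanism it adds: the new `dataloader_args` field is unpacked via `**self.dataloader_args` into every `DataLoader(...)` call the trainer makes, so any keyword accepted by `torch.utils.data.DataLoader` can be supplied. The snippet below stands alone and mimics the forwarding that `create_dataloaders` now does internally; the dataset and keyword choices are illustrative only.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Any valid torch.utils.data.DataLoader keyword can go in the dict;
    # the trainer forwards it verbatim to the train, test, and
    # prediction loaders.
    dataloader_args = {"num_workers": 2, "pin_memory": True}

    # Mimic the forwarding done inside create_dataloaders:
    dataset = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))
    loader = DataLoader(dataset, batch_size=4, shuffle=True, **dataloader_args)

One caveat of plain `**`-unpacking: keys the trainer already passes explicitly (`batch_size`, `shuffle`, `collate_fn`) would reach `DataLoader` twice and raise a `TypeError`, so they should not appear in `dataloader_args`.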