From f5a3b646ac9ff7ad3aab29f16ab089ec5261b388 Mon Sep 17 00:00:00 2001 From: Jasper van Selm <70692744+schobbejak@users.noreply.github.com> Date: Mon, 10 Jun 2024 15:48:36 +0200 Subject: [PATCH] Fix collate_fn changing hash of torch trainer. --- epochalyst/pipeline/model/training/torch_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/epochalyst/pipeline/model/training/torch_trainer.py b/epochalyst/pipeline/model/training/torch_trainer.py index 95da046..3cf3790 100644 --- a/epochalyst/pipeline/model/training/torch_trainer.py +++ b/epochalyst/pipeline/model/training/torch_trainer.py @@ -165,7 +165,7 @@ def log_to_terminal(self, message: str) -> None: epochs: Annotated[int, Gt(0)] = 10 patience: Annotated[int, Gt(0)] = 5 # Early stopping batch_size: Annotated[int, Gt(0)] = 32 - collate_fn: Callable[[tuple[Tensor, ...]], tuple[Tensor, ...]] = custom_collate + collate_fn: Callable[[tuple[Tensor, ...]], tuple[Tensor, ...]] = field(default=custom_collate, init=True, repr=False, compare=False) # Checkpointing checkpointing_enabled: bool = field(default=True, init=True, repr=False, compare=False)