diff --git a/epochalyst/pipeline/model/training/torch_trainer.py b/epochalyst/pipeline/model/training/torch_trainer.py index 494b3e6..4af6790 100644 --- a/epochalyst/pipeline/model/training/torch_trainer.py +++ b/epochalyst/pipeline/model/training/torch_trainer.py @@ -5,7 +5,6 @@ from pathlib import Path from typing import Annotated, Any, Callable, List, TypeVar - import numpy as np import numpy.typing as npt import torch @@ -191,15 +190,6 @@ def custom_train( fold = train_args.get("fold", -1) self.save_model_to_disk = save_model - if self._model_exists(): - self.log_to_terminal( - f"Model exists in {self.model_directory}/{self.get_hash()}.pt, loading model" - ) - self._load_model() - return self.custom_predict(x), y - - self.log_to_terminal(f"Training model: {self.model.__class__.__name__}") - self.log_to_debug(f"Training model: {self.model.__class__.__name__}") # Create datasets train_dataset, test_dataset = self.create_datasets( @@ -209,6 +199,24 @@ def custom_train( # Create dataloaders train_loader, test_loader = self.create_dataloaders(train_dataset, test_dataset) + concat_dataset: Dataset[Any] = self._concat_datasets( + train_dataset, test_dataset, train_indices, test_indices + ) + pred_dataloader = DataLoader( + concat_dataset, batch_size=self.batch_size, shuffle=False + ) + + if self._model_exists(): + self.log_to_terminal( + f"Model exists in {self.model_directory}/{self.get_hash()}.pt, loading model" + ) + self._load_model() + # Return the predictions + return self.predict_on_loader(pred_dataloader), y + + self.log_to_terminal(f"Training model: {self.model.__class__.__name__}") + self.log_to_debug(f"Training model: {self.model.__class__.__name__}") + # Train the model self.log_to_terminal(f"Training model for {self.epochs} epochs") @@ -237,15 +245,6 @@ def custom_train( if save_model: self._save_model() - # Return the predictions - concat_dataset: Dataset[Any] = self._concat_datasets( - train_dataset, test_dataset, train_indices, test_indices - ) - - pred_dataloader = DataLoader( - concat_dataset, batch_size=self.batch_size, shuffle=False - ) - return self.predict_on_loader(pred_dataloader), y def custom_predict( diff --git a/pyproject.toml b/pyproject.toml index 3a96d91..f695243 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "epochalyst" -version = "0.1.6" +version = "0.1.7" authors = [ { name = "Jasper van Selm", email = "jmvanselm@gmail.com" }, { name = "Ariel Ebersberger", email = "arielebersberger@gmail.com" },