Skip to content

Commit

Permalink
Merge pull request #92 from TeamEpochGithub/91-high-prio-refactor-torch-trainer-to-run-concat_dataset-when-model-is-loaded
Browse files Browse the repository at this point in the history

Update torch_trainer.py
  • Loading branch information
schobbejak authored Mar 22, 2024
2 parents 940ee0f + de07cad commit 17cdb91
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions epochalyst/pipeline/model/training/torch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pathlib import Path
from typing import Annotated, Any, Callable, List, TypeVar


import numpy as np
import numpy.typing as npt
import torch
Expand Down Expand Up @@ -191,15 +190,6 @@ def custom_train(
fold = train_args.get("fold", -1)

self.save_model_to_disk = save_model
if self._model_exists():
self.log_to_terminal(
f"Model exists in {self.model_directory}/{self.get_hash()}.pt, loading model"
)
self._load_model()
return self.custom_predict(x), y

self.log_to_terminal(f"Training model: {self.model.__class__.__name__}")
self.log_to_debug(f"Training model: {self.model.__class__.__name__}")

# Create datasets
train_dataset, test_dataset = self.create_datasets(
Expand All @@ -209,6 +199,24 @@ def custom_train(
# Create dataloaders
train_loader, test_loader = self.create_dataloaders(train_dataset, test_dataset)

concat_dataset: Dataset[Any] = self._concat_datasets(
train_dataset, test_dataset, train_indices, test_indices
)
pred_dataloader = DataLoader(
concat_dataset, batch_size=self.batch_size, shuffle=False
)

if self._model_exists():
self.log_to_terminal(
f"Model exists in {self.model_directory}/{self.get_hash()}.pt, loading model"
)
self._load_model()
# Return the predictions
return self.predict_on_loader(pred_dataloader), y

self.log_to_terminal(f"Training model: {self.model.__class__.__name__}")
self.log_to_debug(f"Training model: {self.model.__class__.__name__}")

# Train the model
self.log_to_terminal(f"Training model for {self.epochs} epochs")

Expand Down Expand Up @@ -237,15 +245,6 @@ def custom_train(
if save_model:
self._save_model()

# Return the predictions
concat_dataset: Dataset[Any] = self._concat_datasets(
train_dataset, test_dataset, train_indices, test_indices
)

pred_dataloader = DataLoader(
concat_dataset, batch_size=self.batch_size, shuffle=False
)

return self.predict_on_loader(pred_dataloader), y

def custom_predict(
Expand Down

0 comments on commit 17cdb91

Please sign in to comment.