Skip to content

Commit

Permalink
Merge pull request #93 from TeamEpochGithub/v0.1.7
Browse files Browse the repository at this point in the history
V0.1.7
  • Loading branch information
schobbejak authored Mar 22, 2024
2 parents 8ea45fe + 17cdb91 commit 896ee00
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 20 deletions.
37 changes: 18 additions & 19 deletions epochalyst/pipeline/model/training/torch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pathlib import Path
from typing import Annotated, Any, Callable, List, TypeVar


import numpy as np
import numpy.typing as npt
import torch
Expand Down Expand Up @@ -191,15 +190,6 @@ def custom_train(
fold = train_args.get("fold", -1)

self.save_model_to_disk = save_model
if self._model_exists():
self.log_to_terminal(
f"Model exists in {self.model_directory}/{self.get_hash()}.pt, loading model"
)
self._load_model()
return self.custom_predict(x), y

self.log_to_terminal(f"Training model: {self.model.__class__.__name__}")
self.log_to_debug(f"Training model: {self.model.__class__.__name__}")

# Create datasets
train_dataset, test_dataset = self.create_datasets(
Expand All @@ -209,6 +199,24 @@ def custom_train(
# Create dataloaders
train_loader, test_loader = self.create_dataloaders(train_dataset, test_dataset)

concat_dataset: Dataset[Any] = self._concat_datasets(
train_dataset, test_dataset, train_indices, test_indices
)
pred_dataloader = DataLoader(
concat_dataset, batch_size=self.batch_size, shuffle=False
)

if self._model_exists():
self.log_to_terminal(
f"Model exists in {self.model_directory}/{self.get_hash()}.pt, loading model"
)
self._load_model()
# Return the predictions
return self.predict_on_loader(pred_dataloader), y

self.log_to_terminal(f"Training model: {self.model.__class__.__name__}")
self.log_to_debug(f"Training model: {self.model.__class__.__name__}")

# Train the model
self.log_to_terminal(f"Training model for {self.epochs} epochs")

Expand Down Expand Up @@ -237,15 +245,6 @@ def custom_train(
if save_model:
self._save_model()

# Return the predictions
concat_dataset: Dataset[Any] = self._concat_datasets(
train_dataset, test_dataset, train_indices, test_indices
)

pred_dataloader = DataLoader(
concat_dataset, batch_size=self.batch_size, shuffle=False
)

return self.predict_on_loader(pred_dataloader), y

def custom_predict(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "epochalyst"
version = "0.1.6"
version = "0.1.7"
authors = [
{ name = "Jasper van Selm", email = "jmvanselm@gmail.com" },
{ name = "Ariel Ebersberger", email = "arielebersberger@gmail.com" },
Expand Down

0 comments on commit 896ee00

Please sign in to comment.