Merge pull request #88 from TeamEpochGithub/v0.1.6
V0.1.6
hjdeheer authored Mar 19, 2024
2 parents 49bfad6 + 2fb765f · commit 8ea45fe
Showing 2 changed files with 25 additions and 12 deletions.
35 changes: 24 additions & 11 deletions epochalyst/pipeline/model/training/torch_trainer.py
@@ -177,6 +177,7 @@ def custom_train(
         :param test_indices: The indices to test on.
         :param cache_size: The cache size.
         :param save_model: Whether to save the model.
+        :param fold: Fold number if running cv
         :return: The input and output of the system.
         """
         train_indices = train_args.get("train_indices")
@@ -187,6 +188,7 @@ def custom_train(
             raise ValueError("test_indices not provided")
         cache_size = train_args.get("cache_size", -1)
         save_model = train_args.get("save_model", True)
+        fold = train_args.get("fold", -1)

         self.save_model_to_disk = save_model
         if self._model_exists():
@@ -213,16 +215,13 @@ def custom_train(
         train_losses: list[float] = []
         val_losses: list[float] = []

-        self.external_define_metric("Training/Train Loss", "min")
-        self.external_define_metric("Validation/Validation Loss", "min")
-
         self.lowest_val_loss = np.inf
         if len(test_loader) == 0:
             self.log_to_warning(
                 f"Doing train full, model will be trained for {self.epochs} epochs"
             )

-        self._training_loop(train_loader, test_loader, train_losses, val_losses)
+        self._training_loop(train_loader, test_loader, train_losses, val_losses, fold)

         self.log_to_terminal(
             f"Done training the model: {self.model.__class__.__name__}"
@@ -359,6 +358,7 @@ def _training_loop(
         test_loader: DataLoader[tuple[Tensor, ...]],
         train_losses: list[float],
         val_losses: list[float],
+        fold: int = -1,
     ) -> None:
         """Training loop for the model.

@@ -367,6 +367,14 @@
         :param train_losses: List of train losses.
         :param val_losses: List of validation losses.
         """
+        fold_no = ""
+
+        if fold > -1:
+            fold_no = f"_{fold}"
+
+        self.external_define_metric(f"Training/Train Loss{fold_no}", "epoch")
+        self.external_define_metric(f"Validation/Validation Loss{fold_no}", "epoch")
+
         for epoch in range(self.epochs):
             # Train using train_loader
             train_loss = self._train_one_epoch(train_loader, epoch)
@@ -375,7 +383,10 @@

             # Log train loss
             self.log_to_external(
-                message={"Training/Train Loss": train_losses[-1]}, step=epoch + 1
+                message={
+                    f"Training/Train Loss{fold_no}": train_losses[-1],
+                    "epoch": epoch,
+                }
             )

             # Compute validation loss
@@ -388,8 +399,10 @@

             # Log validation loss and plot train/val loss against each other
             self.log_to_external(
-                message={"Validation/Validation Loss": val_losses[-1]},
-                step=epoch + 1,
+                message={
+                    f"Validation/Validation Loss{fold_no}": val_losses[-1],
+                    "epoch": epoch,
+                }
             )

             self.log_to_external(
@@ -401,8 +414,8 @@
                             range(epoch + 1)
                         ), # Ensure it's a list, not a range object
                         "ys": [train_losses, val_losses],
-                        "keys": ["Train", "Validation"],
-                        "title": "Training/Loss",
+                        "keys": [f"Train{fold_no}", f"Validation{fold_no}"],
+                        "title": f"Training/Loss{fold_no}",
                         "xname": "Epoch",
                     },
                 }
@@ -411,12 +424,12 @@ def _training_loop(
             # Early stopping
             if self._early_stopping():
                 self.log_to_external(
-                    message={"Epochs": (epoch + 1) - self.patience}
+                    message={f"Epochs{fold_no}": (epoch + 1) - self.patience}
                 )
                 break

         # Log the trained epochs to wandb if we finished training
-        self.log_to_external(message={"Epochs": epoch + 1})
+        self.log_to_external(message={f"Epochs{fold_no}": epoch + 1})

     def _train_one_epoch(
         self, dataloader: DataLoader[tuple[Tensor, ...]], epoch: int
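Editor's note on the change above: the new fold argument only changes how metrics are named, so separate cross-validation folds no longer overwrite each other's curves in the external logger. Below is a minimal standalone sketch of that naming rule; the metric_name helper is illustrative only and is not part of this commit (the library builds the suffix inline in _training_loop).

    def metric_name(base: str, fold: int = -1) -> str:
        # Mirrors the fold_no logic added in _training_loop: append "_{fold}"
        # only when a non-negative fold number is given.
        suffix = f"_{fold}" if fold > -1 else ""
        return f"{base}{suffix}"

    # fold defaults to -1, which keeps the pre-0.1.6 metric names unchanged
    assert metric_name("Training/Train Loss") == "Training/Train Loss"
    # with fold=2 (e.g. the third cross-validation split) the name gains a suffix
    assert metric_name("Validation/Validation Loss", 2) == "Validation/Validation Loss_2"

Callers pass the fold number through train_args (for example fold=2), which custom_train reads via train_args.get("fold", -1) and forwards to _training_loop.
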
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "epochalyst"
-version = "0.1.5"
+version = "0.1.6"
 authors = [
     { name = "Jasper van Selm", email = "jmvanselm@gmail.com" },
     { name = "Ariel Ebersberger", email = "arielebersberger@gmail.com" },
