Skip to content

Commit

Permalink
Update self.model_directory to private
Browse files: browse the repository at this point in the history
  • Loading branch information
schobbejak committed Apr 17, 2024
1 parent 1f8cd4a commit 36ffd82
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 34 deletions.
36 changes: 21 additions & 15 deletions epochalyst/pipeline/model/training/torch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class TorchTrainer(TrainingBlock):
- `batch_size` (int): Batch size
- `patience` (int): Patience for early stopping
- `test_size` (float): Relative size of the test set
- `to_predict` (str): Whether to predict on the test set, all data or none
- `to_predict` (str): Whether to predict on the 'test' set, 'all' data or 'none'
- `model_name` (str): Name of the model
- `n_folds` (float): Number of folds for cross validation (0 for train full)
- `fold` (int): Fold number
Expand Down Expand Up @@ -153,7 +153,7 @@ def __post_init__(self) -> None:
"Please specify the number of folds for cross validation or set n_folds to 0 for train full.",
)
self.save_model_to_disk = True
self.model_directory = "tm"
self._model_directory = Path("tm")
self.best_model_state_dict: dict[Any, Any] = {}

# Set optimizer
Expand Down Expand Up @@ -223,7 +223,7 @@ def custom_train(self, x: npt.NDArray[np.float32], y: npt.NDArray[np.float32], *

if self._model_exists():
self.log_to_terminal(
f"Model exists in {self.model_directory}/{self.get_hash()}.pt, loading model",
f"Model exists in {self._model_directory}/{self.get_hash()}.pt, loading model",
)
self._load_model()
# Return the predictions
Expand Down Expand Up @@ -471,12 +471,18 @@ def create_dataloaders(
)
return train_loader, test_loader

def update_model_directory(self, model_directory: str) -> None:
def update_model_directory(self, model_directory: Path) -> None:
"""Update the model directory.
:param model_directory: The model directory.
"""
self.model_directory = model_directory
if model_directory.exists() and model_directory.is_dir():
self._model_directory = model_directory
elif not model_directory.exists():
model_directory.mkdir()
self._model_directory = model_directory
else:
raise ValueError(f"{model_directory} is not a valid model_directory")

def save_model_to_external(self) -> None:
"""Save model to external database."""
Expand Down Expand Up @@ -643,31 +649,31 @@ def _val_one_epoch(
def _save_model(self) -> None:
"""Save the model in the model_directory folder."""
self.log_to_terminal(
f"Saving model to {self.model_directory}/{self.get_hash()}.pt",
f"Saving model to {self._model_directory}/{self.get_hash()}.pt",
)
path = Path(self.model_directory)
path = Path(self._model_directory)
if not Path.exists(path):
Path.mkdir(path)

torch.save(self.model, f"{self.model_directory}/{self.get_hash()}.pt")
torch.save(self.model, f"{self._model_directory}/{self.get_hash()}.pt")
self.log_to_terminal(
f"Model saved to {self.model_directory}/{self.get_hash()}.pt",
f"Model saved to {self._model_directory}/{self.get_hash()}.pt",
)
self.save_model_to_external()

def _load_model(self) -> None:
"""Load the model from the model_directory folder."""
# Check if the model exists
if not Path(f"{self.model_directory}/{self.get_hash()}.pt").exists():
if not Path(f"{self._model_directory}/{self.get_hash()}.pt").exists():
raise FileNotFoundError(
f"Model not found in {self.model_directory}/{self.get_hash()}.pt",
f"Model not found in {self._model_directory}/{self.get_hash()}.pt",
)

# Load model
self.log_to_terminal(
f"Loading model from {self.model_directory}/{self.get_hash()}.pt",
f"Loading model from {self._model_directory}/{self.get_hash()}.pt",
)
checkpoint = torch.load(f"{self.model_directory}/{self.get_hash()}.pt")
checkpoint = torch.load(f"{self._model_directory}/{self.get_hash()}.pt")

# Load the weights from the checkpoint
if isinstance(checkpoint, nn.DataParallel):
Expand All @@ -682,12 +688,12 @@ def _load_model(self) -> None:
self.model.load_state_dict(model.state_dict())

self.log_to_terminal(
f"Model loaded from {self.model_directory}/{self.get_hash()}.pt",
f"Model loaded from {self._model_directory}/{self.get_hash()}.pt",
)

def _model_exists(self) -> bool:
"""Check if the model exists in the model_directory folder."""
return Path(f"{self.model_directory}/{self.get_hash()}.pt").exists() and self.save_model_to_disk
return Path(f"{self._model_directory}/{self.get_hash()}.pt").exists() and self.save_model_to_disk

def _early_stopping(self) -> bool:
"""Check if early stopping should be performed.
Expand Down
40 changes: 21 additions & 19 deletions tests/pipeline/model/training/test_torch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass
from typing import Any
from unittest.mock import patch
from pathlib import Path

import numpy as np
import torch
Expand All @@ -16,6 +17,7 @@ class TestTorchTrainer:
simple_model = torch.nn.Linear(1, 1)
optimizer = functools.partial(torch.optim.SGD, lr=0.01)
scheduler = functools.partial(torch.optim.lr_scheduler.StepLR, step_size=1)
model_path = Path("tests/cache")

class ImplementedTorchTrainer(TorchTrainer):
def __post_init__(self):
Expand Down Expand Up @@ -171,7 +173,7 @@ def test_train(self):
criterion=torch.nn.MSELoss(),
optimizer=self.optimizer,
)
tt.update_model_directory("tests/cache")
tt.update_model_directory(self.model_path)
x = torch.rand(10, 1)
y = torch.rand(10)
tt.train(x, y, train_indices=[0, 1, 2, 3, 4, 5, 6, 7], test_indices=[8, 9])
Expand All @@ -184,7 +186,7 @@ def test_train_trained(self):
criterion=torch.nn.MSELoss(),
optimizer=self.optimizer,
)
tt.update_model_directory("tests/cache")
tt.update_model_directory(self.model_path)
x = torch.rand(10, 1)
y = torch.rand(10)
tt.train(x, y, train_indices=[0, 1, 2, 3, 4, 5, 6, 7], test_indices=[8, 9])
Expand All @@ -199,7 +201,7 @@ def test_train_full(self):
optimizer=self.optimizer,
)
tt.n_folds = 0
tt.update_model_directory("tests/cache")
tt.update_model_directory(self.model_path)
x = torch.rand(10, 1)
y = torch.rand(10)

Expand All @@ -215,7 +217,7 @@ def test_early_stopping(self):
optimizer=self.optimizer,
patience=-1,
)
tt.update_model_directory("tests/cache")
tt.update_model_directory(self.model_path)
x = torch.rand(10, 1)
y = torch.rand(10)
tt.train(x, y, train_indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], test_indices=[])
Expand All @@ -238,7 +240,7 @@ def test_predict(self):
criterion=torch.nn.MSELoss(),
optimizer=self.optimizer,
)
tt.update_model_directory("tests/cache")
tt.update_model_directory(self.model_path)
x = torch.rand(10, 1)
y = torch.rand(10)
tt.train(
Expand All @@ -256,7 +258,7 @@ def test_predict_3fold(self):
)
remove_cache_files()
tt.n_folds = 3
tt.update_model_directory("tests/cache")
tt.update_model_directory(self.model_path)
x = torch.rand(10, 1)
y = torch.rand(10)
tt.train(
Expand All @@ -280,7 +282,7 @@ def test_predict_train_full(self):
)
remove_cache_files()
tt.n_folds = 0
tt.update_model_directory("tests/cache")
tt.update_model_directory(self.model_path)
x = torch.rand(10, 1)
y = torch.rand(10)
tt.train(x, y, train_indices=[0, 1, 2, 3, 4, 5, 6, 7], test_indices=[])
Expand All @@ -294,7 +296,7 @@ def test_predict2(self):
criterion=torch.nn.MSELoss(),
optimizer=self.optimizer,
)
tt.update_model_directory("tests/cache")
tt.update_model_directory(self.model_path)
x = torch.rand(10, 1)
y = torch.rand(10)
tt.train(
Expand All @@ -314,7 +316,7 @@ def test_predict_all(self):
optimizer=self.optimizer,
to_predict="all",
)
tt.update_model_directory("tests/cache")
tt.update_model_directory(self.model_path)
x = torch.rand(10, 1)
y = torch.rand(10)
train_preds = tt.train(
Expand All @@ -335,14 +337,14 @@ def test_predict_2d(self):
optimizer=self.optimizer,
to_predict="all",
)
tt.update_model_directory("tests/cache")
tt.update_model_directory(self.model_path)
x = torch.rand(10, 2)
y = torch.rand(10, 2)
train_preds = tt.train(
x,
y,
train_indices=np.array([0, 1, 2, 3, 4, 5, 6, 7]),
test_indices=np.array([8, 9]),fold=0
test_indices=np.array([8, 9]), fold=0
)
preds = tt.predict(x)
assert len(train_preds[0]) == 10
Expand All @@ -357,14 +359,14 @@ def test_predict_partial(self):
optimizer=self.optimizer,
to_predict="test",
)
tt.update_model_directory("tests/cache")
tt.update_model_directory(self.model_path)
x = torch.rand(10, 1)
y = torch.rand(10)
train_preds = tt.train(
x,
y,
train_indices=np.array([0, 1, 2, 3, 4, 5, 6, 7]),
test_indices=np.array([8, 9]),fold=0
test_indices=np.array([8, 9]), fold=0
)
preds = tt.predict(x)
assert len(train_preds[0]) == 2
Expand All @@ -379,14 +381,14 @@ def test_predict_none(self):
optimizer=self.optimizer,
to_predict="none",
)
tt.update_model_directory("tests/cache")
tt.update_model_directory(self.model_path)
x = torch.rand(10, 1)
y = torch.rand(10)
train_preds = tt.train(
x,
y,
train_indices=np.array([0, 1, 2, 3, 4, 5, 6, 7]),
test_indices=np.array([8, 9]),fold=0
test_indices=np.array([8, 9]), fold=0
)
preds = tt.predict(x)
assert len(train_preds[0]) == 10
Expand All @@ -411,7 +413,7 @@ def test_train_with_scheduler(self):
optimizer=self.optimizer,
scheduler=self.scheduler,
)
tt.update_model_directory("tests/cache")
tt.update_model_directory(self.model_path)
x = torch.rand(10, 1)
y = torch.rand(10)
tt.train(x, y, train_indices=[0, 1, 2, 3, 4, 5, 6, 7], test_indices=[8, 9])
Expand All @@ -427,7 +429,7 @@ def test_train_one_gpu(self):
optimizer=self.optimizer,
)

tt.update_model_directory("tests/cache/tm")
tt.update_model_directory(self.model_path / "tm")
x = torch.rand(10, 1)
y = torch.rand(10)
tt.train(x, y, train_indices=[0, 1, 2, 3, 4, 5, 6, 7], test_indices=[8, 9])
Expand All @@ -442,7 +444,7 @@ def test_train_one_gpu_saved(self):
optimizer=self.optimizer,
)

tt.update_model_directory("tests/cache")
tt.update_model_directory(self.model_path)
x = torch.rand(10, 1)
y = torch.rand(10)
tt.train(x, y, train_indices=[0, 1, 2, 3, 4, 5, 6, 7], test_indices=[8, 9])
Expand All @@ -462,7 +464,7 @@ def test_train_two_gpu_saved(self):
optimizer=self.optimizer,
)

tt.update_model_directory("tests/cache")
tt.update_model_directory(self.model_path)
x = torch.rand(10, 1)
y = torch.rand(10)
tt.train(x, y, train_indices=[0, 1, 2, 3, 4, 5, 6, 7], test_indices=[8, 9])
Expand Down

0 comments on commit 36ffd82

Please sign in to comment.