Change torch trainer to numpy and rename n_folds
hjdeheer committed Apr 16, 2024
1 parent 37d9cfb commit b03404b
Showing 2 changed files with 68 additions and 25 deletions.
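In short, the renamed `n_folds` field now carries the cross-validation setup end to end: the default `-1` is rejected in `__post_init__`, `0` means one model trained on the full data, and any value of 1 or more trains one model per fold and ensembles them at predict time. A minimal standalone sketch of that contract (`FoldConfig` and `describe` are illustrative stand-ins, not part of epochalyst):

from dataclasses import dataclass, field


@dataclass
class FoldConfig:
    """Illustrative stand-in for the n_folds contract on TorchTrainer."""

    n_folds: float = field(default=-1)

    def __post_init__(self) -> None:
        # Mirrors TorchTrainer.__post_init__: the -1 default is rejected so callers
        # must choose explicitly between train-full (0) and k-fold CV (>= 1).
        if self.n_folds == -1:
            raise ValueError(
                "Please specify the number of folds for cross validation "
                "or set n_folds to 0 for train full."
            )

    def describe(self) -> str:
        if self.n_folds < 1:
            return "single model, no fold ensembling"
        return f"{int(self.n_folds)}-fold CV, one model per fold, averaged at predict time"


print(FoldConfig(n_folds=0).describe())  # single model, no fold ensembling
print(FoldConfig(n_folds=3).describe())  # 3-fold CV, one model per fold, averaged at predict time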
41 changes: 22 additions & 19 deletions epochalyst/pipeline/model/training/torch_trainer.py
@@ -36,6 +36,8 @@ class TorchTrainer(TrainingBlock):
- `batch_size` (int): Batch size
- `patience` (int): Patience for early stopping
- `test_size` (float): Relative size of the test set
+    - `n_folds` (float): Number of folds for cross validation (0 for train full)
+    - `fold` (int): Fold number
Methods:
.. code-block:: python
@@ -125,14 +127,14 @@ def log_to_terminal(self, message: str) -> None:
test_size: Annotated[float, Interval(ge=0, le=1)] = 0.2 # Hashing purposes

_fold: int = field(default=-1, init=False, repr=False, compare=False)
-    test_split_type: float = field(default=-1, init=True, repr=False, compare=False)
+    n_folds: float = field(default=-1, init=True, repr=False, compare=False)

def __post_init__(self) -> None:
"""Post init method for the TorchTrainer class."""

-        if self.test_split_type == -1:
+        if self.n_folds == -1:
raise ValueError(
"Train_split_type needs to be set to either test_size or n_folds"
"Please specify the number of folds for cross validation or set n_folds to 0 for train full."
)

self.save_model_to_disk = True
@@ -171,7 +173,7 @@ def custom_train(
x: npt.NDArray[np.float32],
y: npt.NDArray[np.float32],
**train_args: Any,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
"""Train the model.
:param x: The input to the system.
@@ -218,7 +220,7 @@ def custom_train(
)
self._load_model()
# Return the predictions
-            return self.predict_on_loader(pred_dataloader), torch.tensor(y)
+            return self.predict_on_loader(pred_dataloader), y

self.log_to_terminal(f"Training model: {self.model.__class__.__name__}")
self.log_to_debug(f"Training model: {self.model.__class__.__name__}")
@@ -253,11 +255,11 @@ def custom_train(
if save_model:
self._save_model()

-        return self.predict_on_loader(pred_dataloader), torch.tensor(y)
+        return self.predict_on_loader(pred_dataloader), y

def custom_predict(
self, x: npt.NDArray[np.float32], **pred_args: Any
-    ) -> torch.Tensor:
+    ) -> npt.NDArray[np.float32]:
"""Predict on the test data
:param x: The input to the system.
@@ -277,26 +279,27 @@ def custom_predict(
pred_dataset, batch_size=curr_batch_size, shuffle=False
)

-        # Predict with a single model, test_split_type lower than 1 means a single test size, no CV
-        if self.test_split_type < 1 or pred_args.get("use_single_model", False):
+        # Predict with a single model, n_folds lower than 1 means a single test size, no CV
+        if self.n_folds < 1 or pred_args.get("use_single_model", False):
self._load_model()
return self.predict_on_loader(pred_dataloader)

# Ensemble the fold models:
predictions = []
-        for i in range(int(self.test_split_type)):
-            self.log_to_terminal(
-                f"Predicting with model fold {i + 1}/{self.test_split_type}"
-            )
+        for i in range(int(self.n_folds)):
+            self.log_to_terminal(f"Predicting with model fold {i + 1}/{self.n_folds}")
self._fold = i # set the fold, which updates the hash
self._load_model() # load the model for this fold
predictions.append(self.predict_on_loader(pred_dataloader))

-        test_predictions = torch.stack(predictions)
+        # Average the predictions using numpy
+        test_predictions = np.array(predictions)

-        return torch.mean(test_predictions, dim=0)
+        return np.mean(test_predictions, axis=0)

-    def predict_on_loader(self, loader: DataLoader[tuple[Tensor, ...]]) -> torch.Tensor:
+    def predict_on_loader(
+        self, loader: DataLoader[tuple[Tensor, ...]]
+    ) -> npt.NDArray[np.float32]:
"""Predict on the loader.
:param loader: The loader to predict on.
Expand All @@ -309,11 +312,11 @@ def predict_on_loader(self, loader: DataLoader[tuple[Tensor, ...]]) -> torch.Ten
for data in tepoch:
X_batch = data[0].to(self.device).float()

-                y_pred = self.model(X_batch).cpu()
+                y_pred = self.model(X_batch).cpu().numpy()
predictions.extend(y_pred)

self.log_to_terminal("Done predicting")
-        return torch.stack(predictions)
+        return np.array(predictions)

def get_hash(self) -> str:
"""Get the hash of the block.
@@ -322,7 +325,7 @@ def get_hash(self) -> str:
:return: The hash of the block.
"""
result = f"{self._hash}_{self.test_split_type}"
result = f"{self._hash}_{self.n_folds}"
if self._fold != -1:
result += f"_f{self._fold}"
return result
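The predict path above is the core of the change: with `n_folds` of 1 or more, `custom_predict` loads one model per fold (the fold index is mixed into the cache hash via `get_hash`) and averages the fold predictions in numpy rather than torch. A self-contained sketch of that ensembling step, where `predict_fold` is a hypothetical stand-in for the per-fold `_load_model` plus `predict_on_loader` sequence:

from typing import Callable

import numpy as np
import numpy.typing as npt


def ensemble_predict(
    predict_fold: Callable[[int], npt.NDArray[np.float32]],
    n_folds: int,
) -> npt.NDArray[np.float32]:
    # Collect per-fold predictions into shape (n_folds, n_samples, ...).
    predictions = [predict_fold(i) for i in range(n_folds)]
    test_predictions = np.array(predictions)
    # Average across the fold axis, as in np.mean(test_predictions, axis=0).
    return np.mean(test_predictions, axis=0)


# Example with dummy fold outputs:
rng = np.random.default_rng(0)
fold_outputs = [rng.random((10, 2)).astype(np.float32) for _ in range(3)]
averaged = ensemble_predict(lambda i: fold_outputs[i], n_folds=3)
print(averaged.shape)  # (10, 2)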
52 changes: 46 additions & 6 deletions tests/pipeline/model/training/test_torch_trainer.py
@@ -16,7 +16,7 @@ class TestTorchTrainer:

class ImplementedTorchTrainer(TorchTrainer):
def __post_init__(self):
-            self.test_split_type = 1
+            self.n_folds = 1
super().__post_init__()

def log_to_terminal(self, message: str) -> None:
@@ -28,7 +28,7 @@ def log_to_debug(self, message: str) -> None:
@dataclass
class FullyImplementedTorchTrainer(TorchTrainer):
def __post_init__(self):
-            self.test_split_type = 1
+            self.n_folds = 1
super().__post_init__()

def log_to_terminal(self, message: str) -> None:
@@ -48,7 +48,7 @@ def log_to_warning(self, message: str) -> None:

def test_init_no_args(self):
with pytest.raises(TypeError):
-            TorchTrainer(test_split_type=1)
+            TorchTrainer(n_folds=1)

def test_init_none_args(self):
with pytest.raises(TypeError):
@@ -57,7 +57,7 @@ def test_init_none_args(self):
criterion=None,
optimizer=None,
device=None,
-                test_split_type=1,
+                n_folds=1,
)

def test_init_proper_args(self):
@@ -66,7 +66,7 @@ def test_init_proper_args(self):
model=self.simple_model,
criterion=torch.nn.MSELoss(),
optimizer=self.optimizer,
-            test_split_type=0,
+            n_folds=0,
)

def test_init_proper_args_with_implemented(self):
@@ -195,7 +195,7 @@ def test_train_full(self):
criterion=torch.nn.MSELoss(),
optimizer=self.optimizer,
)
-        tt.test_split_type = 0
+        tt.n_folds = 0
tt.update_model_directory("tests/cache")
x = torch.rand(10, 1)
y = torch.rand(10)
@@ -245,6 +245,46 @@ def test_predict(self):

remove_cache_files()

+    def test_predict_3fold(self):
+        tt = self.FullyImplementedTorchTrainer(
+            model=self.simple_model,
+            criterion=torch.nn.MSELoss(),
+            optimizer=self.optimizer,
+        )
+        remove_cache_files()
+        tt.n_folds = 3
+        tt.update_model_directory("tests/cache")
+        x = torch.rand(10, 1)
+        y = torch.rand(10)
+        tt.train(
+            x, y, train_indices=[0, 1, 2, 3, 4, 5, 6, 7], test_indices=[8, 9], fold=0
+        )
+        tt.train(
+            x, y, train_indices=[0, 1, 2, 3, 4, 5, 6, 7], test_indices=[8, 9], fold=1
+        )
+        tt.train(
+            x, y, train_indices=[0, 1, 2, 3, 4, 5, 6, 7], test_indices=[8, 9], fold=2
+        )
+        tt.predict(x)
+
+        remove_cache_files()
+
+    def test_predict_train_full(self):
+        tt = self.FullyImplementedTorchTrainer(
+            model=self.simple_model,
+            criterion=torch.nn.MSELoss(),
+            optimizer=self.optimizer,
+        )
+        remove_cache_files()
+        tt.n_folds = 0
+        tt.update_model_directory("tests/cache")
+        x = torch.rand(10, 1)
+        y = torch.rand(10)
+        tt.train(x, y, train_indices=[0, 1, 2, 3, 4, 5, 6, 7], test_indices=[])
+        tt.predict(x)
+
+        remove_cache_files()

def test_predict_no_model_trained(self):
tt = self.FullyImplementedTorchTrainer(
model=self.simple_model,
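The two new tests above depend on fold-aware caching: each `tt.train(..., fold=i)` call saves a model under a fold-specific hash, which `tt.predict` later reloads for the ensemble. A small sketch of the hash scheme from `get_hash`, with a placeholder base hash (`abc123` is illustrative):

def fold_hash(base_hash: str, n_folds: float, fold: int = -1) -> str:
    # Mirrors TorchTrainer.get_hash: n_folds is always appended, and the
    # current fold is appended only when one is set (i.e. not -1).
    result = f"{base_hash}_{n_folds}"
    if fold != -1:
        result += f"_f{fold}"
    return result


print(fold_hash("abc123", 3, fold=0))  # abc123_3_f0  (fold model 1 of 3)
print(fold_hash("abc123", 3, fold=2))  # abc123_3_f2  (fold model 3 of 3)
print(fold_hash("abc123", 0))          # abc123_0     (train full: single model)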
