Skip to content

Commit

Permalink
Weighted mean loss per epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
burggraaff committed Apr 26, 2024
1 parent 5e998be commit 79135df
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
29 changes: 26 additions & 3 deletions fpcup/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
default_outline = {"color": "black", "linewidth": 0.5}


### PLOTTING FUNCTIONS
### GEOSPATIAL PLOTS
def _configure_map_panels(axs: plt.Axes | Iterable[plt.Axes],
province: Province | Iterable[Province]=NETHERLANDS, **kwargs) -> None:
"""
Expand Down Expand Up @@ -477,6 +477,29 @@ def plot_wofost_summary(summary: Summary, keys: Iterable[str]=KEYS_AGGREGATE_PLO
plot_wofost_summary_byprovince = partial(wofost_summary_geo, rasterized=True, province=provinces.values(), use_coarse=True)


def weighted_mean_loss(loss_per_batch: np.ndarray) -> np.ndarray:
"""
Return the weighted mean loss per epoch, weighted with a sawtooth.
"""
# Check dimensionality
INPUT_IS_1D = (loss_per_batch.ndim == 1)
if INPUT_IS_1D:
loss_per_batch = loss_per_batch[np.newaxis, :]

# Generate sawtooth
n_batches = loss_per_batch.shape[1]
sawtooth = np.arange(n_batches) + 1

# Weighted mean
loss_per_epoch = np.average(loss_per_batch, weights=sawtooth, axis=1)

# Return 1D if the input was 1D
if INPUT_IS_1D:
loss_per_epoch = loss_per_epoch[0]

return loss_per_epoch


def plot_loss_curve(losses_train: np.ndarray, *, losses_test: Optional[np.ndarray]=None,
title: Optional[str]=None, saveto: Optional[PathOrStr]=None) -> None:
"""
Expand All @@ -493,14 +516,14 @@ def plot_loss_curve(losses_train: np.ndarray, *, losses_test: Optional[np.ndarra

# Pull out data
loss_initial = [losses_train[0, 0]]
losses_train_epoch = losses_train[:, -1]
losses_train_epoch = weighted_mean_loss(losses_train)
losses_train_epoch = np.concatenate([loss_initial, losses_train_epoch])

losses_train_batch = losses_train.ravel()

# Variables for limits etc.
try:
maxloss = max(losses_train.max(), losses_test.max())
maxloss = np.nanmax(losses_train.max(), losses_test.max())
except AttributeError: # is no test losses were provided
maxloss = losses_train.max()

Expand Down
8 changes: 8 additions & 0 deletions nn/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ def train_batch(model: nn.Module, loss_function: Callable, optimizer: torch.opti
return loss.item()


def test_batch(model: nn.Module, loss_function: Callable, optimizer: torch.optim.Optimizer, X: Tensor, y: Tensor) -> float:
"""
Test a given neural network `model` on data.
One batch.
"""
pass


def train_epoch(model: nn.Module, dataloader: DataLoader, loss_function: Callable, optimizer: torch.optim.Optimizer) -> list[float]:
"""
Train a given neural network `model` on data.
Expand Down

0 comments on commit 79135df

Please sign in to comment.