Merge pull request #56 from discovery-unicamp/improvements-to-logging
Improvements to logging
GabrielBG0 authored May 13, 2024
2 parents c85b6e6 + c341979 commit d800b6c
Showing 1 changed file with 120 additions and 9 deletions.

minerva/models/nets/setr.py
@@ -4,6 +4,7 @@
 import lightning as L
 import torch
 from torch import nn
+from torchmetrics import JaccardIndex

 from minerva.models.nets.vit import _VisionTransformerBackbone
 from minerva.utils.upsample import Upsample, resize
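For reference, the newly imported JaccardIndex is torchmetrics' IoU metric. A minimal standalone sanity check, with purely illustrative values:

    import torch
    from torchmetrics import JaccardIndex

    jac = JaccardIndex(task="multiclass", num_classes=2)
    preds = torch.tensor([0, 1, 1, 0])   # predicted class per sample
    target = torch.tensor([0, 1, 0, 0])  # ground-truth class per sample
    print(jac(preds, target))            # mean IoU over the two classes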
@@ -406,9 +407,15 @@ def __init__(
         conv_act: Optional[nn.Module] = None,
         interpolate_mode: str = "bilinear",
         loss_fn: Optional[nn.Module] = None,
+        log_train_metrics: bool = False,
+        train_metrics: Optional[nn.Module] = None,
+        log_val_metrics: bool = False,
+        val_metrics: Optional[nn.Module] = None,
+        log_test_metrics: bool = False,
+        test_metrics: Optional[nn.Module] = None,
         aux_output: bool = True,
         aux_output_layers: list[int] | None = [9, 14, 19],
-        aux_weights: list[float] = [0.4, 0.4, 0.4],
+        aux_weights: list[float] = [0.3, 0.3, 0.3],
     ):
         """
         Initializes the SetR model.
@@ -453,6 +460,10 @@ def __init__(
             The interpolation mode for upsampling in the decoder. Defaults to "bilinear".
         loss_fn : nn.Module, optional
            The loss function to be used during training. Defaults to None.
+        log_train_metrics : bool
+            Whether to log metrics during training. Defaults to False.
+        train_metrics : nn.Module, optional
+            The metric to evaluate on training outputs (e.g. a torchmetrics JaccardIndex). Must be provided if log_train_metrics is True.
+        log_val_metrics : bool
+            Whether to log metrics during validation. Defaults to False.
+        val_metrics : nn.Module, optional
+            The metric to evaluate on validation outputs. Must be provided if log_val_metrics is True.
+        log_test_metrics : bool
+            Whether to log metrics during testing. Defaults to False.
+        test_metrics : nn.Module, optional
+            The metric to evaluate on test outputs. Must be provided if log_test_metrics is True.
         """
         super().__init__()
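The per-phase metric objects are ordinary nn.Module instances passed in by the caller. A minimal construction sketch, assuming the torchmetrics JaccardIndex imported above; the public wrapper name SETR_PUP and the constructor defaults are assumptions, since this diff only shows the internal _SetR_PUP:

    from torchmetrics import JaccardIndex

    # One metric object per phase, so their internal states never mix.
    train_miou = JaccardIndex(task="multiclass", num_classes=6)
    val_miou = JaccardIndex(task="multiclass", num_classes=6)

    model = SETR_PUP(  # hypothetical public wrapper name
        num_classes=6,
        log_train_metrics=True,
        train_metrics=train_miou,
        log_val_metrics=True,
        val_metrics=val_miou,
    )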
@@ -475,6 +486,28 @@ def __init__(
         self.num_classes = num_classes
         self.aux_weights = aux_weights

+        self.log_train_metrics = log_train_metrics
+        self.log_val_metrics = log_val_metrics
+        self.log_test_metrics = log_test_metrics
+
+        if log_train_metrics:
+            assert (
+                train_metrics is not None
+            ), "train_metrics must be provided if log_train_metrics is True"
+            self.train_metrics = train_metrics
+
+        if log_val_metrics:
+            assert (
+                val_metrics is not None
+            ), "val_metrics must be provided if log_val_metrics is True"
+            self.val_metrics = val_metrics
+
+        if log_test_metrics:
+            assert (
+                test_metrics is not None
+            ), "test_metrics must be provided if log_test_metrics is True"
+            self.test_metrics = test_metrics
+
         self.model = _SetR_PUP(
             image_size=image_size,
             patch_size=patch_size,
@@ -498,6 +531,15 @@ def __init__(
             aux_output_layers=aux_output_layers,
         )

+        self.train_step_outputs = []
+        self.train_step_labels = []
+
+        self.val_step_outputs = []
+        self.val_step_labels = []
+
+        self.test_step_outputs = []
+        self.test_step_labels = []
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.model(x)

@@ -534,13 +576,10 @@ def _loss_func(
                 + (loss_aux2 * self.aux_weights[1])
                 + (loss_aux3 * self.aux_weights[2])
             )

         loss = self.loss_fn(y_hat, y.long())
         return loss
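The auxiliary-head weighting above scales each deep-supervision loss by its entry in aux_weights (now defaulting to [0.3, 0.3, 0.3]) and adds it to the main loss. The top of _loss_func is collapsed in this view, so the main-loss term and the use of cross-entropy below are assumptions; a minimal standalone sketch of the weighted sum:

    import torch
    from torch import nn

    ce = nn.CrossEntropyLoss()
    aux_weights = [0.3, 0.3, 0.3]

    logits_main = torch.randn(2, 6, 32, 32)                      # (batch, classes, H, W)
    logits_aux = [torch.randn(2, 6, 32, 32) for _ in range(3)]   # one per aux head
    target = torch.randint(0, 6, (2, 32, 32))                    # integer class map

    # Main loss plus each auxiliary loss scaled by its weight.
    loss = ce(logits_main, target) + sum(
        w * ce(l, target) for w, l in zip(aux_weights, logits_aux)
    )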

-    def _single_step(
-        self, batch: torch.Tensor, batch_idx: int, step_name: str
-    ) -> torch.Tensor:
+    def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str):
         """Perform a single step of the training/validation loop.

         Parameters
@@ -558,18 +597,90 @@ def _single_step(
             The loss value.
         """
         x, y = batch
-        y_hat = self.model(x)
+        y_hat = self.model(x.float())
         loss = self._loss_func(y_hat[0], y.squeeze(1))
+
+        if step_name == "train":
+            self.train_step_outputs.append(y_hat[0])
+            self.train_step_labels.append(y)
+        elif step_name == "val":
+            self.val_step_outputs.append(y_hat[0])
+            self.val_step_labels.append(y)
+        elif step_name == "test":
+            self.test_step_outputs.append(y_hat[0])
+            self.test_step_labels.append(y)
+
-        self.log(
-            f"{step_name}_loss",
-            loss,
+        self.log_dict(
+            {
+                f"{step_name}_loss": loss,
+            },
             on_step=True,
             on_epoch=True,
             prog_bar=True,
             logger=True,
         )

         return loss

+    def on_train_epoch_end(self):
+        if self.log_train_metrics:
+            y_hat = torch.cat(self.train_step_outputs)
+            y = torch.cat(self.train_step_labels)
+            preds = torch.argmax(y_hat, dim=1, keepdim=True)
+            self.train_metrics(preds, y)
+            mIoU = self.train_metrics.compute()
+
+            self.log_dict(
+                {
+                    "train_metrics": mIoU,
+                },
+                on_step=False,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
+        # Clear unconditionally: _single_step appends on every step, even
+        # when metric logging is disabled, so the buffers must not grow
+        # across epochs.
+        self.train_step_outputs.clear()
+        self.train_step_labels.clear()
+
+    def on_validation_epoch_end(self):
+        if self.log_val_metrics:
+            y_hat = torch.cat(self.val_step_outputs)
+            y = torch.cat(self.val_step_labels)
+            preds = torch.argmax(y_hat, dim=1, keepdim=True)
+            self.val_metrics(preds, y)
+            mIoU = self.val_metrics.compute()
+
+            self.log_dict(
+                {
+                    "val_metrics": mIoU,
+                },
+                on_step=False,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
+        self.val_step_outputs.clear()
+        self.val_step_labels.clear()
+
+    def on_test_epoch_end(self):
+        if self.log_test_metrics:
+            y_hat = torch.cat(self.test_step_outputs)
+            y = torch.cat(self.test_step_labels)
+            preds = torch.argmax(y_hat, dim=1, keepdim=True)
+            self.test_metrics(preds, y)
+            mIoU = self.test_metrics.compute()
+
+            self.log_dict(
+                {
+                    "test_metrics": mIoU,
+                },
+                on_step=False,
+                on_epoch=True,
+                prog_bar=True,
+                logger=True,
+            )
+        self.test_step_outputs.clear()
+        self.test_step_labels.clear()
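All three hooks follow the same accumulate-then-compute pattern. A minimal standalone sketch of one epoch-end computation, assuming a 6-class multiclass JaccardIndex and random stand-ins for the per-step logits and labels:

    import torch
    from torchmetrics import JaccardIndex

    metric = JaccardIndex(task="multiclass", num_classes=6)

    # Stand-ins for self.*_step_outputs / self.*_step_labels after an epoch.
    step_outputs = [torch.randn(2, 6, 16, 16) for _ in range(4)]            # per-step logits (B, C, H, W)
    step_labels = [torch.randint(0, 6, (2, 1, 16, 16)) for _ in range(4)]   # per-step masks (B, 1, H, W)

    y_hat = torch.cat(step_outputs)                     # (N, C, H, W)
    y = torch.cat(step_labels)                          # (N, 1, H, W)
    preds = torch.argmax(y_hat, dim=1, keepdim=True)    # (N, 1, H, W) class indices
    metric(preds, y)                                    # update internal state
    print(metric.compute())                             # scalar mean IoU for the epoch

Note that a torchmetrics module keeps its state until reset() is called; calling reset() after compute() (not shown in this commit) would keep each epoch's mIoU independent of the previous ones.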

     def training_step(self, batch: torch.Tensor, batch_idx: int):
         return self._single_step(batch, batch_idx, "train")
