diff --git a/minerva/models/nets/setr.py b/minerva/models/nets/setr.py index 7a91827..4d83130 100644 --- a/minerva/models/nets/setr.py +++ b/minerva/models/nets/setr.py @@ -594,8 +594,9 @@ def _single_step(self, batch: torch.Tensor, batch_idx: int, step_name: str): metrics = self._compute_metrics(y_hat[0], y, step_name) for metric_name, metric_value in metrics.items(): - self.log_dict( - {metric_name: metric_value}, + self.log( + metric_name, + metric_value, on_step=False, on_epoch=True, prog_bar=True,