diff --git a/znnl/loss_functions/simple_loss.py b/znnl/loss_functions/simple_loss.py index f574bee..40ec2a6 100644 --- a/znnl/loss_functions/simple_loss.py +++ b/znnl/loss_functions/simple_loss.py @@ -73,6 +73,6 @@ def __call__( total loss of all points based on the similarity measurement """ if mask is not None: - return np.sum(self.metric(point_1, point_2) * mask, axis=0) + return np.mean(self.metric(point_1, point_2) * mask, axis=0) else: - return np.sum(self.metric(point_1, point_2), axis=0) + return np.mean(self.metric(point_1, point_2), axis=0)