Skip to content

Commit

Permalink
fix summation checking in contrastive loss
Browse files Browse the repository at this point in the history
  • Loading branch information
KonstiNik committed Apr 22, 2024
1 parent 60f5a73 commit e854df7
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion znnl/loss_functions/contrastive_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def __call__(self, inputs: np.ndarray, targets: np.ndarray) -> float:

losses = self.compute_losses(inputs, targets)

if len(losses) > 1:
if isinstance(losses, tuple):
return np.array([losses]).sum()
else:
return losses

0 comments on commit e854df7

Please sign in to comment.