diff --git a/sematic/examples/mnist/pytorch/train_eval.py b/sematic/examples/mnist/pytorch/train_eval.py index 0c453228e..d5a856b32 100644 --- a/sematic/examples/mnist/pytorch/train_eval.py +++ b/sematic/examples/mnist/pytorch/train_eval.py @@ -109,7 +109,7 @@ def test(model: nn.Module, device: torch.device, test_loader: DataLoader): correct += pred.eq(target).sum().item() test_loss /= len(test_loader.dataset) # type: ignore - pr_curve = PrecisionRecallCurve(num_classes=10) + pr_curve = PrecisionRecallCurve(num_classes=10, task="multiclass") precision, recall, thresholds = pr_curve(torch.cat(probas), torch.cat(targets)) classes = [] for i in range(10):