From 11512f99f0843f1ab9bc6fe3d4c0c9b226b65cfd Mon Sep 17 00:00:00 2001 From: Sophia Reiner Date: Fri, 2 Aug 2024 09:09:01 -0600 Subject: [PATCH] fix call to calc_uncertainty --- ptype/callbacks.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ptype/callbacks.py b/ptype/callbacks.py index cefd047..72bb318 100644 --- a/ptype/callbacks.py +++ b/ptype/callbacks.py @@ -5,7 +5,6 @@ CSVLogger, EarlyStopping, ) -from mlguess.keras.models import calc_prob_uncertainty from tensorflow.python.keras.callbacks import ReduceLROnPlateau from sklearn.metrics import precision_recall_fscore_support, roc_auc_score from hagelslag.evaluation.ProbabilityMetrics import DistributedROC @@ -68,7 +67,7 @@ def __init__(self, x, y, name="val", n_bins = 10, use_uncertainty = False, **kwa def on_epoch_end(self, epoch, logs={}): pred_probs = np.asarray(self.model.predict(self.x)) if self.use_uncertainty: - pred_probs, _, _, _ = calc_prob_uncertainty(pred_probs) + pred_probs, _, _, _ = self.model.calc_uncertainty(pred_probs) pred_probs = pred_probs.numpy() logs[f"{self.name}_csi"] = self.mean_csi(pred_probs) true_labels = np.argmax(self.y, 1)