Skip to content

Commit

Permalink
Implement safe ROC at inference
Browse files — browse the repository at this point in the history
  • Loading branch information
KristinaUlicna committed Oct 19, 2023
1 parent 76c64c7 commit e7c9981
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions grace/evaluation/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from sklearn.metrics import (
accuracy_score,
precision_recall_fscore_support,
roc_auc_score,
average_precision_score,
)

from grace.styling import LOGGER
from grace.base import GraphAttrs, Annotation, Prediction
from grace.models.datasets import dataset_from_graph
from grace.evaluation.metrics_classifier import safe_roc_auc_score
from grace.visualisation.plotting import (
plot_confusion_matrix_tiles,
plot_areas_under_curves,
Expand Down Expand Up @@ -359,11 +359,11 @@ def calculate_numerical_results_on_entire_batch(
inference_batch_metrics["Batch F1-score (edges)"] = prf1_edges[2]

# AUC scores:
inference_batch_metrics["Batch AUROC (nodes)"] = roc_auc_score(
inference_batch_metrics["Batch AUROC (nodes)"] = safe_roc_auc_score(
y_true=predictions_data["n_true"],
y_score=predictions_data["n_prob"],
)
inference_batch_metrics["Batch AUROC (edges)"] = roc_auc_score(
inference_batch_metrics["Batch AUROC (edges)"] = safe_roc_auc_score(
y_true=predictions_data["e_true"],
y_score=predictions_data["e_prob"],
)
Expand Down
12 changes: 6 additions & 6 deletions grace/visualisation/plotting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from grace.base import GraphAttrs, Annotation
from grace.styling import COLORMAPS
from grace.base import GraphAttrs, Annotation
from grace.evaluation.metrics_classifier import safe_roc_auc_score

import matplotlib.pyplot as plt
import networkx as nx
Expand All @@ -10,11 +11,10 @@

from skimage.util import montage
from sklearn.metrics import (
ConfusionMatrixDisplay,
roc_auc_score,
RocCurveDisplay,
average_precision_score,
ConfusionMatrixDisplay,
PrecisionRecallDisplay,
RocCurveDisplay,
)


Expand Down Expand Up @@ -179,7 +179,7 @@ def plot_areas_under_curves(
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=figsize)

# Area under ROC:
roc_score_nodes = roc_auc_score(y_true=node_true, y_score=node_pred)
roc_score_nodes = safe_roc_auc_score(y_true=node_true, y_score=node_pred)
RocCurveDisplay.from_predictions(
y_true=node_true,
y_pred=node_pred,
Expand All @@ -189,7 +189,7 @@ def plot_areas_under_curves(
ax=axes[0],
)

roc_score_edges = roc_auc_score(y_true=edge_true, y_score=edge_pred)
roc_score_edges = safe_roc_auc_score(y_true=edge_true, y_score=edge_pred)
RocCurveDisplay.from_predictions(
y_true=edge_true,
y_pred=edge_pred,
Expand Down

0 comments on commit e7c9981

Please sign in to comment.