Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
style: linting
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBernstorff authored and bokajgd committed Feb 16, 2023
1 parent 8e92d94 commit 10b6c51
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ def plot_precision_recall(
Union[None, Path]: None if save_path is None, else path to saved figure.
"""
precision, recall, _ = precision_recall_curve(
y_true=eval_dataset.y, probas_pred=eval_dataset.y_hat_probs
y_true=eval_dataset.y,
probas_pred=eval_dataset.y_hat_probs,
)

auprc = average_precision_score(
y_true=eval_dataset.y, y_score=eval_dataset.y_hat_probs
y_true=eval_dataset.y,
y_score=eval_dataset.y_hat_probs,
)

legend_label = "AUPRC = "
Expand Down
3 changes: 2 additions & 1 deletion tests/model_evaluation/test_visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,5 +186,6 @@ def test_plot_roc_auc(synth_eval_dataset: EvalDataset):

def test_plot_precision_recall(synth_eval_dataset: EvalDataset):
plot_precision_recall(
eval_dataset=synth_eval_dataset, save_path="tmp/test_plot_precision_recall.png"
eval_dataset=synth_eval_dataset,
save_path="tmp/test_plot_precision_recall.png",
)

0 comments on commit 10b6c51

Please sign in to comment.