From 10b6c51f55da57d3646810d1566a5469961eb4ab Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Thu, 16 Feb 2023 10:29:18 +0000 Subject: [PATCH] style: linting --- .../model_eval/base_artifacts/plots/precision_recall.py | 6 ++++-- tests/model_evaluation/test_visualizations.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/psycop_model_training/model_eval/base_artifacts/plots/precision_recall.py b/src/psycop_model_training/model_eval/base_artifacts/plots/precision_recall.py index 5250291e..efb56a90 100644 --- a/src/psycop_model_training/model_eval/base_artifacts/plots/precision_recall.py +++ b/src/psycop_model_training/model_eval/base_artifacts/plots/precision_recall.py @@ -28,11 +28,13 @@ def plot_precision_recall( Union[None, Path]: None if save_path is None, else path to saved figure. """ precision, recall, _ = precision_recall_curve( - y_true=eval_dataset.y, probas_pred=eval_dataset.y_hat_probs + y_true=eval_dataset.y, + probas_pred=eval_dataset.y_hat_probs, ) auprc = average_precision_score( - y_true=eval_dataset.y, y_score=eval_dataset.y_hat_probs + y_true=eval_dataset.y, + y_score=eval_dataset.y_hat_probs, ) legend_label = "AUPRC = " diff --git a/tests/model_evaluation/test_visualizations.py b/tests/model_evaluation/test_visualizations.py index be0260d1..cd5fa2ce 100644 --- a/tests/model_evaluation/test_visualizations.py +++ b/tests/model_evaluation/test_visualizations.py @@ -186,5 +186,6 @@ def test_plot_roc_auc(synth_eval_dataset: EvalDataset): def test_plot_precision_recall(synth_eval_dataset: EvalDataset): plot_precision_recall( - eval_dataset=synth_eval_dataset, save_path="tmp/test_plot_precision_recall.png" + eval_dataset=synth_eval_dataset, + save_path="tmp/test_plot_precision_recall.png", )