diff --git a/src/psycopt2d/tables/performance_by_threshold.py b/src/psycopt2d/tables/performance_by_threshold.py
index e0996b55..205512af 100644
--- a/src/psycopt2d/tables/performance_by_threshold.py
+++ b/src/psycopt2d/tables/performance_by_threshold.py
@@ -1,5 +1,5 @@
 """Get performance by which threshold is used to classify positive."""
-from collections.abc import Iterable, Sequence
+from collections.abc import Iterable
 from typing import Optional, Union
 
 import numpy as np
@@ -10,17 +10,47 @@ from psycopt2d.evaluation_dataclasses import EvalDataset
 
 
+def get_true_positives(
+    eval_dataset: EvalDataset,
+    positive_rate_threshold: Optional[float] = 0.5,
+):
+    """Get dataframe containing only true positives.
+
+    Args:
+        eval_dataset (EvalDataset): EvalDataset object.
+        positive_rate_threshold (float, optional): Threshold above which patients are classified as positive. Defaults to 0.5.
+
+    Returns:
+        pd.DataFrame: Dataframe containing only true positives.
+    """
+
+    # Generate df
+    df = pd.DataFrame(
+        {
+            "id": eval_dataset.ids,
+            "pred_probs": eval_dataset.y_hat_probs,
+            "pred_timestamps": eval_dataset.pred_timestamps,
+            "outcome_timestamps": eval_dataset.outcome_timestamps,
+        },
+    )
+
+    # Keep only true positives
+    df["true_positive"] = (df["pred_probs"] >= positive_rate_threshold) & (
+        df["outcome_timestamps"].notnull()
+    )
+
+    return df[df["true_positive"]]
+
+
 def performance_by_threshold(  # pylint: disable=too-many-locals
-    labels: Sequence[int],
-    pred_probs: Sequence[float],
+    eval_dataset: EvalDataset,
     positive_threshold: float,
     round_to: int = 4,
 ) -> pd.DataFrame:
     """Generates a row for a performance_by_threshold table.
 
     Args:
-        labels (Iterable[int]): True labels.
-        pred_probs (Iterable[float]): Model prediction probabilities.
+        eval_dataset (EvalDataset): EvalDataset object.
         positive_threshold (float): Threshold for a probability to be labelled as "positive".
         round_to (int): Number of decimal places to round metrics
 
@@ -28,9 +58,9 @@ def performance_by_threshold(  # pylint: disable=too-many-locals
     Returns:
         pd.DataFrame
     """
-    preds = np.where(pred_probs > positive_threshold, 1, 0)  # type: ignore
+    preds = np.where(eval_dataset.y_hat_probs > positive_threshold, 1, 0)  # type: ignore
 
-    conf_matrix = confusion_matrix(labels, preds)
+    conf_matrix = confusion_matrix(eval_dataset.y, preds)
 
     true_neg = conf_matrix[0][0]
     false_neg = conf_matrix[1][0]
@@ -79,10 +109,7 @@ def performance_by_threshold(  # pylint: disable=too-many-locals
 
 
 def days_from_first_positive_to_diagnosis(
-    ids: Iterable[Union[float, str]],
-    pred_probs: Iterable[Union[float, str]],
-    pred_timestamps: Iterable[pd.Timestamp],
-    outcome_timestamps: Iterable[pd.Timestamp],
+    eval_dataset: EvalDataset,
     positive_rate_threshold: Optional[float] = 0.5,
     aggregation_method: Optional[str] = "sum",
 ) -> float:
@@ -90,31 +117,18 @@ def days_from_first_positive_to_diagnosis(
     patient's outcome timestamp.
 
     Args:
-        ids (Iterable[Union[float, str]]): Iterable of patient IDs.
-        pred_probs (Iterable[Union[float, str]]): Predicted probabilities.
-        pred_timestamps (Iterable[pd.Timestamp]): Timestamps for each prediction time.
-        outcome_timestamps (Iterable[pd.Timestamp]): Timestamps of patient outcome.
+        eval_dataset (EvalDataset): EvalDataset object.
         positive_rate_threshold (float, optional): Threshold above which patients are classified as positive. Defaults to 0.5.
         aggregation_method (str, optional): How to aggregate the warning days. Defaults to "sum".
 
     Returns:
-        float: _description_
+        float: Aggregated number of days from first positive prediction to outcome (sum or mean, depending on aggregation_method).
     """
-    # Generate df
-    df = pd.DataFrame(
-        {
-            "id": ids,
-            "pred_probs": pred_probs,
-            "pred_timestamps": pred_timestamps,
-            "outcome_timestamps": outcome_timestamps,
-        },
-    )
-
-    # Keep only true positives
-    df["true_positive"] = (df["pred_probs"] >= positive_rate_threshold) & (
-        df["outcome_timestamps"].notnull()
+    # Generate df with only true positives
+    df = get_true_positives(
+        eval_dataset=eval_dataset,
+        positive_rate_threshold=positive_rate_threshold,
     )
-    df = df[df["true_positive"]]
 
     # Find timestamp of first positive prediction
     df["timestamp_first_pos_pred"] = df.groupby("id")["pred_timestamps"].transform(
@@ -146,6 +160,29 @@ def days_from_first_positive_to_diagnosis(
     return df["warning_days"].agg(aggregation_method)
 
 
+def prop_with_at_least_one_true_positive(
+    eval_dataset: EvalDataset,
+    positive_rate_threshold: Optional[float] = 0.5,
+) -> float:
+    """Get proportion of patients with at least one true positive prediction.
+
+    Args:
+        eval_dataset (EvalDataset): EvalDataset object.
+        positive_rate_threshold (float, optional): Threshold above which patients are classified as positive. Defaults to 0.5.
+
+    Returns:
+        float: Proportion of patients with at least one true positive.
+    """
+    # Generate df with only true positives
+    df = get_true_positives(
+        eval_dataset=eval_dataset,
+        positive_rate_threshold=positive_rate_threshold,
+    )
+
+    # Return the proportion of unique patients with at least one true positive
+    return round(df["id"].nunique() / len(set(eval_dataset.ids)), 4)
+
+
 def generate_performance_by_positive_rate_table(
     eval_dataset: EvalDataset,
     positive_rate_thresholds: Iterable[Union[int, float]],
@@ -175,18 +212,14 @@ def generate_performance_by_positive_rate_table(
     # For each percentile, calculate relevant performance metrics
     for threshold_value in pred_proba_thresholds:
         threshold_metrics = performance_by_threshold(
-            labels=eval_dataset.y,
-            pred_probs=eval_dataset.y_hat_probs,
+            eval_dataset=eval_dataset,
             positive_threshold=threshold_value,
         )
 
         threshold_metrics[  # pylint: disable=unsupported-assignment-operation
             "total_warning_days"
         ] = days_from_first_positive_to_diagnosis(
-            ids=eval_dataset.ids,
-            pred_probs=eval_dataset.y_hat_probs,
-            pred_timestamps=eval_dataset.pred_timestamps,
-            outcome_timestamps=eval_dataset.outcome_timestamps,
+            eval_dataset=eval_dataset,
             positive_rate_threshold=threshold_value,
             aggregation_method="sum",
         )
@@ -195,16 +228,20 @@ def generate_performance_by_positive_rate_table(
             "mean_warning_days"
         ] = round(
             days_from_first_positive_to_diagnosis(
-                ids=eval_dataset.ids,
-                pred_probs=eval_dataset.y_hat_probs,
-                pred_timestamps=eval_dataset.pred_timestamps,
-                outcome_timestamps=eval_dataset.outcome_timestamps,
+                eval_dataset=eval_dataset,
                 positive_rate_threshold=threshold_value,
                 aggregation_method="mean",
             ),
             0,
         )
 
+        threshold_metrics[  # pylint: disable=unsupported-assignment-operation
+            "prop_with_at_least_one_true_positive"
+        ] = prop_with_at_least_one_true_positive(
+            eval_dataset=eval_dataset,
+            positive_rate_threshold=threshold_value,
+        )
+
         rows.append(threshold_metrics)
 
     df = pd.concat(rows)
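A minimal usage sketch of the refactored helpers above (separate from the patch itself). It assumes EvalDataset can be instantiated directly with the fields this diff accesses (ids, y, y_hat_probs, pred_timestamps, outcome_timestamps); the actual dataclass in psycopt2d.evaluation_dataclasses may require more, and the toy patients and dates are purely illustrative:

    import pandas as pd

    from psycopt2d.evaluation_dataclasses import EvalDataset
    from psycopt2d.tables.performance_by_threshold import (
        days_from_first_positive_to_diagnosis,
        get_true_positives,
        prop_with_at_least_one_true_positive,
    )

    # Hypothetical toy data: two predictions for patient 1, one for patient 2.
    # Patient 2 never develops the outcome, so their outcome timestamp is NaT.
    eval_ds = EvalDataset(  # assumed constructor; field names taken from the diff
        ids=[1, 1, 2],
        y=[1, 1, 0],
        y_hat_probs=[0.2, 0.8, 0.6],
        pred_timestamps=pd.to_datetime(["2020-01-01", "2020-02-01", "2020-03-01"]),
        outcome_timestamps=pd.to_datetime(["2020-06-01", "2020-06-01", pd.NaT]),
    )

    # At threshold 0.5, only patient 1's 0.8 prediction is a true positive:
    # patient 2 crosses the threshold but has no outcome timestamp.
    tp_df = get_true_positives(eval_dataset=eval_ds, positive_rate_threshold=0.5)

    # Days from each patient's first positive prediction to their outcome,
    # summed; here that should be the 121 days from 2020-02-01 to 2020-06-01.
    total_days = days_from_first_positive_to_diagnosis(
        eval_dataset=eval_ds,
        positive_rate_threshold=0.5,
        aggregation_method="sum",
    )

    # 1 of 2 unique patients has at least one true positive -> 0.5
    prop = prop_with_at_least_one_true_positive(
        eval_dataset=eval_ds,
        positive_rate_threshold=0.5,
    )
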
diff --git a/tests/model_evaluation/test_performance_by_threshold.py b/tests/model_evaluation/test_performance_by_threshold.py
index db200c6a..5f5a5afc 100644
--- a/tests/model_evaluation/test_performance_by_threshold.py
+++ b/tests/model_evaluation/test_performance_by_threshold.py
@@ -38,38 +38,107 @@ def test_generate_performance_by_threshold_table(synth_eval_dataset: EvalDataset
 
     expected_df = pd.DataFrame(
         {
-            "threshold_percentile": {0: 90.0, 1: 50.0, 2: 10.0},
-            "true_prevalence": {0: 0.0502, 1: 0.0502, 2: 0.0502},
-            "positive_rate": {0: 0.1, 1: 0.5, 2: 0.5511},
-            "negative_rate": {0: 0.9, 1: 0.5, 2: 0.4489},
-            "PPV": {0: 0.0508, 1: 0.0502, 2: 0.0502},
-            "NPV": {0: 0.9498, 1: 0.9497, 2: 0.9497},
-            "sensitivity": {0: 0.1011, 1: 0.4997, 2: 0.5503},
-            "specificity": {0: 0.9001, 1: 0.5, 2: 0.4488},
-            "FPR": {0: 0.0999, 1: 0.5, 2: 0.5512},
-            "FNR": {0: 0.8989, 1: 0.5003, 2: 0.4497},
-            "accuracy": {0: 0.8599, 1: 0.5, 2: 0.4539},
-            "true_positives": {0: 508, 1: 2510, 2: 2764},
-            "true_negatives": {0: 85485, 1: 47487, 2: 42627},
-            "false_positives": {0: 9492, 1: 47490, 2: 52350},
-            "false_negatives": {0: 4515, 1: 2513, 2: 2259},
-            "total_warning_days": {0: 609757.0, 1: 2619787.0, 2: 4612729.0},
-            "warning_days_per_false_positive": {0: 64.2, 1: 55.2, 2: 88.1},
-            "mean_warning_days": {0: 1252, 1: 1332, 2: 1451},
+            "true_prevalence": [
+                0.0502,
+                0.0502,
+                0.0502,
+            ],
+            "positive_rate": [
+                0.5511,
+                0.5,
+                0.1,
+            ],
+            "negative_rate": [
+                0.4489,
+                0.5,
+                0.9,
+            ],
+            "PPV": [
+                0.0502,
+                0.0502,
+                0.0508,
+            ],
+            "NPV": [
+                0.9497,
+                0.9497,
+                0.9498,
+            ],
+            "sensitivity": [
+                0.5503,
+                0.4997,
+                0.1011,
+            ],
+            "specificity": [
+                0.4488,
+                0.5,
+                0.9001,
+            ],
+            "FPR": [
+                0.5512,
+                0.5,
+                0.0999,
+            ],
+            "FNR": [
+                0.4497,
+                0.5003,
+                0.8989,
+            ],
+            "accuracy": [
+                0.4539,
+                0.5,
+                0.8599,
+            ],
+            "true_positives": [
+                2764,
+                2510,
+                508,
+            ],
+            "true_negatives": [
+                42627,
+                47487,
+                85485,
+            ],
+            "false_positives": [
+                52350,
+                47490,
+                9492,
+            ],
+            "false_negatives": [
+                2259,
+                2513,
+                4515,
+            ],
+            "total_warning_days": [
+                4612729.0,
+                2619787.0,
+                609757.0,
+            ],
+            "warning_days_per_false_positive": [
+                88.1,
+                55.2,
+                64.2,
+            ],
+            "mean_warning_days": [
+                1451.0,
+                1332.0,
+                1252.0,
+            ],
+            "prop_with_at_least_one_true_positive": [
+                0.0503,
+                0.0311,
+                0.0077,
+            ],
         },
     )
 
     for col in output_table.columns:
-        output_table[col].equals(expected_df[col])
+        assert output_table[col].equals(expected_df[col])
 
 
 def test_time_from_flag_to_diag(synth_eval_dataset: EvalDataset):
     # Threshold = 0.5
     val = days_from_first_positive_to_diagnosis(
-        ids=synth_eval_dataset.ids,
-        pred_probs=synth_eval_dataset.y_hat_probs,
-        pred_timestamps=synth_eval_dataset.pred_timestamps,
-        outcome_timestamps=synth_eval_dataset.outcome_timestamps,
+        eval_dataset=synth_eval_dataset,
         positive_rate_threshold=0.5,
     )
 
@@ -77,10 +146,7 @@ def test_time_from_flag_to_diag(synth_eval_dataset: EvalDataset):
 
     # Threshold = 0.2
     val = days_from_first_positive_to_diagnosis(
-        ids=synth_eval_dataset.ids,
-        pred_probs=synth_eval_dataset.y_hat_probs,
-        pred_timestamps=synth_eval_dataset.pred_timestamps,
-        outcome_timestamps=synth_eval_dataset.outcome_timestamps,
+        eval_dataset=synth_eval_dataset,
         positive_rate_threshold=0.2,
     )
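
A side note on the change from `output_table[col].equals(expected_df[col])` to the asserted form: pandas' Series.equals returns a bool rather than raising on mismatch, so the old loop computed a comparison and silently discarded it, meaning the test could never fail. A minimal illustration:

    import pandas as pd

    a = pd.Series([1, 2])
    b = pd.Series([9, 9])

    a.equals(b)         # returns False, but the bool was discarded -> test always passed
    assert a.equals(b)  # with assert, a mismatch now raises AssertionError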