From df985b5e12d2e5c3f8b9f22c04cdfe1ca793412b Mon Sep 17 00:00:00 2001 From: jaamarks Date: Tue, 8 Oct 2024 09:51:39 -0400 Subject: [PATCH] Refactored ``_retain_valid_discordant_replicates`` - Improved data type handling - Updated corresponding tests in ``test_sample_qc_table.py`` Overall logic of retaining valid discordant replicates remains the same, but this code fixes some dtype issues with the last implementation. --- .../workflow/scripts/sample_qc_table.py | 87 ++++++------ .../workflow/scripts/test_sample_qc_table.py | 131 +++++++++++++++--- 2 files changed, 155 insertions(+), 63 deletions(-) diff --git a/src/cgr_gwas_qc/workflow/scripts/sample_qc_table.py b/src/cgr_gwas_qc/workflow/scripts/sample_qc_table.py index aae91755..a877933b 100755 --- a/src/cgr_gwas_qc/workflow/scripts/sample_qc_table.py +++ b/src/cgr_gwas_qc/workflow/scripts/sample_qc_table.py @@ -417,7 +417,7 @@ def _read_contam(file_name: Optional[Path], Sample_IDs: pd.Index) -> pd.DataFram return pd.DataFrame( index=Sample_IDs, columns=["Contamination_Rate", "is_contaminated"], - ).astype({"Contamination_Rate": pd.NA, "is_contaminated": pd.NA}) + ).astype({"Contamination_Rate": "float", "is_contaminated": "boolean"}) return ( agg_contamination.read(file_name) @@ -464,12 +464,12 @@ def add_qc_columns( remove_rep_discordant: bool, ) -> pd.DataFrame: add_call_rate_flags(sample_qc) - _add_identifiler(sample_qc) _add_analytic_exclusion( sample_qc, remove_contam, remove_rep_discordant, ) + _add_identifiler(sample_qc) _add_subject_representative(sample_qc) _add_subject_dropped_from_study(sample_qc) @@ -516,64 +516,53 @@ def reason_string(row: pd.Series) -> str: def _retain_valid_discordant_replicates( sample_qc: pd.DataFrame, ) -> pd.DataFrame: - """Check and update the status of a pair of samples labeled as - discordant expected replicates. - - This function verifies if the provided sample pair is labeled as - discordant expected replicates. 
- If they are, it checks for contamination or low call rate flags on - each sample. - If one of the samples is found to be contaminated or has a low call - rate, the function retains the non-contaminated and non-low-call-rate - sample, updating its status to remove the expected replicate label. + """Updates the status of discordant expected replicates. + + Given a pair of discordant expected replicates, it checks for contamination + or low call rate flags on each sample. If one sample is found to be + contaminated or has a low call rate, the function updates the + "is_discordant_replicate" status of the non-contaminated and + non-low-call-rate sample to "False". This way, it can be retained for + subject-level analysis. """ - # Assuming sample_qc is your DataFrame - if "Sample_ID" in sample_qc.columns: - sample_qc = sample_qc.set_index("Sample_ID") - - # Iterate through each sample in the DataFrame for index, row in sample_qc.iterrows(): - # Check if the sample is a discordant expected replicate if row["is_discordant_replicate"]: - # Get the list of other samples it is discordant with - discordant_samples = row["replicate_ids"].split("|") - # Initialize flags for contamination and call rate issues - is_current_sample_low_call_rate = row["is_cr2_filtered"] + is_current_sample_low_call_rate = row["is_call_rate_filtered"] is_current_sample_contaminated = ( - row["is_contaminated"] if not pd.isna(row["is_contaminated"]) else False + # Treat pd.NA as not contaminated + row["is_contaminated"] + if not pd.isna(row["is_contaminated"]) + else False ) - # Initialize a flag to track if all discordant samples have issues + # flag to track if all discordant samples have issues all_other_samples_issue = True + discordant_samples = row["replicate_ids"].split("|") - # Check each discordant sample for sample_id in discordant_samples: - if sample_id == row.name: - continue + if sample_id == row["Sample_ID"]: + continue # only look at other samples else: - # Get the row for the 
discordant sample - discordant_row = sample_qc.loc[sample_id] - if not discordant_row.empty: - contaminated = ( - discordant_row["is_contaminated"] - if not pd.isna(discordant_row["is_contaminated"]) - else False - ) - low_call_rate = discordant_row["is_cr2_filtered"] - - # Check if the discordant sample is contaminated or has a low call rate - if not contaminated and not low_call_rate: - all_other_samples_issue = False - break - - # If the current sample is not contaminated or low call rate + discordant_row = sample_qc[sample_qc["Sample_ID"] == sample_id].iloc[0] + low_call_rate = discordant_row["is_call_rate_filtered"] + contaminated = ( + discordant_row["is_contaminated"] + if not pd.isna(discordant_row["is_contaminated"]) + else False + ) + + # Check if the discordant sample is contaminated and/or has a low call rate + if not contaminated and not low_call_rate: + all_other_samples_issue = False + break + + # Retain the current sample if not contaminated nor low call rate... if not is_current_sample_contaminated and not is_current_sample_low_call_rate: - # If all other samples have issues, update the current sample's status + # and the other samples to have issues if all_other_samples_issue: sample_qc.at[index, "is_discordant_replicate"] = False - return sample_qc @@ -596,7 +585,7 @@ def _add_analytic_exclusion( "is_cr2_filtered": "Call Rate 2 Filtered", } - _retain_valid_discordant_replicates(sample_qc) + sample_qc = _retain_valid_discordant_replicates(sample_qc) if remove_contam: exclusion_criteria["is_contaminated"] = "Contamination" @@ -604,14 +593,18 @@ def _add_analytic_exclusion( if remove_rep_discordant: exclusion_criteria["is_discordant_replicate"] = "Replicate Discordance" + # adding this new column which is a boolean. 
Checks for any T in a series (e.g., {F, F, F, F, T}.any()) sample_qc["analytic_exclusion"] = sample_qc.reindex(exclusion_criteria.keys(), axis=1).any( axis=1 ) + + # looking across the columns and adding up the Trues sample_qc["num_analytic_exclusion"] = ( sample_qc.reindex(exclusion_criteria.keys(), axis=1).sum(axis=1).astype(int) ) - sample_qc["analytic_exclusion_reason"] = _get_reason(sample_qc, exclusion_criteria) + # get the names of the columns that are True and return a "|" delimited string of dict values whose keys were true. + sample_qc["analytic_exclusion_reason"] = _get_reason(sample_qc, exclusion_criteria) return sample_qc diff --git a/tests/workflow/scripts/test_sample_qc_table.py b/tests/workflow/scripts/test_sample_qc_table.py index 880680d5..d43ad401 100644 --- a/tests/workflow/scripts/test_sample_qc_table.py +++ b/tests/workflow/scripts/test_sample_qc_table.py @@ -254,22 +254,95 @@ def fake_sample_qc() -> pd.DataFrame: "Call_Rate_2", "is_cr1_filtered", "is_cr2_filtered", + "is_call_rate_filtered", "is_contaminated", "is_discordant_replicate", ] data = [ - ("SP00001", "SB00001", "", False, False, 0.99, False, False, False, False), - ("SP00002", "SB00002", "", False, False, 0.82, False, True, False, False), - ("SP00003", "SB00003", "SP00003|SP00004", False, False, 0.99, False, False, True, True), - ("SP00004", "SB00003", "SP00003|SP00004", False, False, 0.99, False, False, False, True), - ("SP00005", "SB00004", "SP00005|SP00006", False, False, 0.99, False, False, False, True), - ("SP00006", "SB00004", "SP00005|SP00006", False, False, 0.99, False, False, False, True), - ("SP00007", "SB00005", "", False, False, 0.99, False, False, False, False), - ("SP00008", "SB00006", "", False, False, 0.99, False, False, False, False), - ("SP00009", "SB00007", "", False, False, 0.99, False, False, False, False), - ("SP00010", "SB00008", "SP00010|SP00011", False, False, 0.99, False, False, False, False), - ("SP00011", "SB00008", "SP00010|SP00011", False, False, 0.94, 
False, False, False, False), + ("SP00001", "SB00001", "", False, False, 0.99, False, False, False, False, False), + ("SP00002", "SB00002", "", False, False, 0.82, False, True, True, False, False), + ( + "SP00003", + "SB00003", + "SP00003|SP00004", + False, + False, + 0.99, + False, + False, + False, + True, + True, + ), + ( + "SP00004", + "SB00003", + "SP00003|SP00004", + False, + False, + 0.99, + False, + False, + False, + False, + True, + ), + ( + "SP00005", + "SB00004", + "SP00005|SP00006", + False, + False, + 0.99, + False, + False, + False, + False, + True, + ), + ( + "SP00006", + "SB00004", + "SP00005|SP00006", + False, + False, + 0.99, + False, + False, + False, + False, + True, + ), + ("SP00007", "SB00005", "", False, False, 0.99, False, False, False, False, False), + ("SP00008", "SB00006", "", False, False, 0.99, False, False, False, False, False), + ("SP00009", "SB00007", "", False, False, 0.99, False, False, False, False, False), + ( + "SP00010", + "SB00008", + "SP00010|SP00011", + False, + False, + 0.99, + False, + False, + False, + False, + False, + ), + ( + "SP00011", + "SB00008", + "SP00010|SP00011", + False, + False, + 0.94, + False, + False, + False, + False, + False, + ), ( "SP00012", "SB00009", @@ -281,6 +354,7 @@ def fake_sample_qc() -> pd.DataFrame: False, False, False, + False, ), ( "SP00013", @@ -293,6 +367,7 @@ def fake_sample_qc() -> pd.DataFrame: False, False, False, + False, ), ( "SP00014", @@ -305,16 +380,41 @@ def fake_sample_qc() -> pd.DataFrame: False, False, False, + False, + ), + ( + "SP00015", + "SB00010", + "SP00015|SP00016", + False, + False, + 0.99, + False, + True, + True, + False, + True, + ), + ( + "SP00016", + "SB00010", + "SP00015|SP00016", + False, + False, + 0.99, + False, + False, + False, + False, + True, ), - ("SP00015", "SB00010", "SP00015|SP00016", False, False, 0.99, False, True, False, True), - ("SP00016", "SB00010", "SP00015|SP00016", False, False, 0.99, False, False, False, True), ] - return pd.DataFrame(data, 
columns=columns).set_index("Sample_ID") + return pd.DataFrame(data, columns=columns) @pytest.mark.parametrize( "contam,rep_discordant,num_removed", - [(False, False, 2), (True, False, 3), (False, True, 5), (True, True, 5)], # call rate filtered + [(False, False, 2), (True, False, 3), (False, True, 5), (True, True, 5)], ) def test_add_analytic_exclusion(fake_sample_qc, contam, rep_discordant, num_removed): pd.set_option("display.max_columns", None) @@ -322,7 +422,6 @@ def test_add_analytic_exclusion(fake_sample_qc, contam, rep_discordant, num_remo assert num_removed == fake_sample_qc.analytic_exclusion.sum() -# change these since I updated fake_sample_qc @pytest.mark.parametrize( "contam,rep_discordant,num_subjects", [(False, False, 9), (True, False, 9), (False, True, 8), (True, True, 8)],