Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: cleaner data handling and improved logic for sample_qc_table.py (issue #324) #345

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 40 additions & 47 deletions src/cgr_gwas_qc/workflow/scripts/sample_qc_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def _read_contam(file_name: Optional[Path], Sample_IDs: pd.Index) -> pd.DataFram
return pd.DataFrame(
index=Sample_IDs,
columns=["Contamination_Rate", "is_contaminated"],
).astype({"Contamination_Rate": pd.NA, "is_contaminated": pd.NA})
).astype({"Contamination_Rate": "float", "is_contaminated": "boolean"})

return (
agg_contamination.read(file_name)
Expand Down Expand Up @@ -464,12 +464,12 @@ def add_qc_columns(
remove_rep_discordant: bool,
) -> pd.DataFrame:
add_call_rate_flags(sample_qc)
_add_identifiler(sample_qc)
_add_analytic_exclusion(
sample_qc,
remove_contam,
remove_rep_discordant,
)
_add_identifiler(sample_qc)
_add_subject_representative(sample_qc)
_add_subject_dropped_from_study(sample_qc)

Expand Down Expand Up @@ -516,64 +516,53 @@ def reason_string(row: pd.Series) -> str:
def _retain_valid_discordant_replicates(
sample_qc: pd.DataFrame,
) -> pd.DataFrame:
"""Check and update the status of a pair of samples labeled as
discordant expected replicates.

This function verifies if the provided sample pair is labeled as
discordant expected replicates.
If they are, it checks for contamination or low call rate flags on
each sample.
If one of the samples is found to be contaminated or has a low call
rate, the function retains the non-contaminated and non-low-call-rate
sample, updating its status to remove the expected replicate label.
"""Updates the status of discordant expected replicates.

Given a pair of discordant expected replicates, it checks for contamination
or low call rate flags on each sample. If one sample is found to be
contaminated or has a low call rate, the function updates the
"is_discordant_replicate" status of the non-contaminated and
non-low-call-rate sample to "False". This way, it can be retained for
subject-level analysis.
"""

# Assuming sample_qc is your DataFrame
if "Sample_ID" in sample_qc.columns:
sample_qc = sample_qc.set_index("Sample_ID")

# Iterate through each sample in the DataFrame
for index, row in sample_qc.iterrows():
# Check if the sample is a discordant expected replicate
if row["is_discordant_replicate"]:
# Get the list of other samples it is discordant with
discordant_samples = row["replicate_ids"].split("|")

# Initialize flags for contamination and call rate issues
is_current_sample_low_call_rate = row["is_cr2_filtered"]
is_current_sample_low_call_rate = row["is_call_rate_filtered"]
is_current_sample_contaminated = (
row["is_contaminated"] if not pd.isna(row["is_contaminated"]) else False
# Treat pd.NA as not contaminated
row["is_contaminated"]
if not pd.isna(row["is_contaminated"])
else False
)

# Initialize a flag to track if all discordant samples have issues
# flag to track if all discordant samples have issues
all_other_samples_issue = True
discordant_samples = row["replicate_ids"].split("|")

# Check each discordant sample
for sample_id in discordant_samples:
if sample_id == row.name:
continue
if sample_id == row["Sample_ID"]:
continue # only look at other samples
else:
# Get the row for the discordant sample
discordant_row = sample_qc.loc[sample_id]
if not discordant_row.empty:
contaminated = (
discordant_row["is_contaminated"]
if not pd.isna(discordant_row["is_contaminated"])
else False
)
low_call_rate = discordant_row["is_cr2_filtered"]

# Check if the discordant sample is contaminated or has a low call rate
if not contaminated and not low_call_rate:
all_other_samples_issue = False
break

# If the current sample is not contaminated or low call rate
discordant_row = sample_qc[sample_qc["Sample_ID"] == sample_id].iloc[0]
low_call_rate = discordant_row["is_call_rate_filtered"]
contaminated = (
discordant_row["is_contaminated"]
if not pd.isna(discordant_row["is_contaminated"])
else False
)

# Check if the discordant sample is contaminated and/or has a low call rate
if not contaminated and not low_call_rate:
all_other_samples_issue = False
break

# Retain the current sample if not contaminated nor low call rate...
if not is_current_sample_contaminated and not is_current_sample_low_call_rate:
# If all other samples have issues, update the current sample's status
# and the other samples to have issues
if all_other_samples_issue:
sample_qc.at[index, "is_discordant_replicate"] = False

return sample_qc


Expand All @@ -596,22 +585,26 @@ def _add_analytic_exclusion(
"is_cr2_filtered": "Call Rate 2 Filtered",
}

_retain_valid_discordant_replicates(sample_qc)
sample_qc = _retain_valid_discordant_replicates(sample_qc)

if remove_contam:
exclusion_criteria["is_contaminated"] = "Contamination"

if remove_rep_discordant:
exclusion_criteria["is_discordant_replicate"] = "Replicate Discordance"

# Boolean column: True when any exclusion criterion is True for the sample (e.g., {F, F, F, F, T}.any())
sample_qc["analytic_exclusion"] = sample_qc.reindex(exclusion_criteria.keys(), axis=1).any(
axis=1
)

# Look across the exclusion criteria columns and count the True values for each sample
sample_qc["num_analytic_exclusion"] = (
sample_qc.reindex(exclusion_criteria.keys(), axis=1).sum(axis=1).astype(int)
)
sample_qc["analytic_exclusion_reason"] = _get_reason(sample_qc, exclusion_criteria)

# get the names of the columns that are True and return a "|" delimited string of dict values whose keys were true.
sample_qc["analytic_exclusion_reason"] = _get_reason(sample_qc, exclusion_criteria)
return sample_qc


Expand Down
131 changes: 115 additions & 16 deletions tests/workflow/scripts/test_sample_qc_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,22 +254,95 @@ def fake_sample_qc() -> pd.DataFrame:
"Call_Rate_2",
"is_cr1_filtered",
"is_cr2_filtered",
"is_call_rate_filtered",
"is_contaminated",
"is_discordant_replicate",
]

data = [
("SP00001", "SB00001", "", False, False, 0.99, False, False, False, False),
("SP00002", "SB00002", "", False, False, 0.82, False, True, False, False),
("SP00003", "SB00003", "SP00003|SP00004", False, False, 0.99, False, False, True, True),
("SP00004", "SB00003", "SP00003|SP00004", False, False, 0.99, False, False, False, True),
("SP00005", "SB00004", "SP00005|SP00006", False, False, 0.99, False, False, False, True),
("SP00006", "SB00004", "SP00005|SP00006", False, False, 0.99, False, False, False, True),
("SP00007", "SB00005", "", False, False, 0.99, False, False, False, False),
("SP00008", "SB00006", "", False, False, 0.99, False, False, False, False),
("SP00009", "SB00007", "", False, False, 0.99, False, False, False, False),
("SP00010", "SB00008", "SP00010|SP00011", False, False, 0.99, False, False, False, False),
("SP00011", "SB00008", "SP00010|SP00011", False, False, 0.94, False, False, False, False),
("SP00001", "SB00001", "", False, False, 0.99, False, False, False, False, False),
("SP00002", "SB00002", "", False, False, 0.82, False, True, True, False, False),
(
"SP00003",
"SB00003",
"SP00003|SP00004",
False,
False,
0.99,
False,
False,
False,
True,
True,
),
(
"SP00004",
"SB00003",
"SP00003|SP00004",
False,
False,
0.99,
False,
False,
False,
False,
True,
),
(
"SP00005",
"SB00004",
"SP00005|SP00006",
False,
False,
0.99,
False,
False,
False,
False,
True,
),
(
"SP00006",
"SB00004",
"SP00005|SP00006",
False,
False,
0.99,
False,
False,
False,
False,
True,
),
("SP00007", "SB00005", "", False, False, 0.99, False, False, False, False, False),
("SP00008", "SB00006", "", False, False, 0.99, False, False, False, False, False),
("SP00009", "SB00007", "", False, False, 0.99, False, False, False, False, False),
(
"SP00010",
"SB00008",
"SP00010|SP00011",
False,
False,
0.99,
False,
False,
False,
False,
False,
),
(
"SP00011",
"SB00008",
"SP00010|SP00011",
False,
False,
0.94,
False,
False,
False,
False,
False,
),
(
"SP00012",
"SB00009",
Expand All @@ -281,6 +354,7 @@ def fake_sample_qc() -> pd.DataFrame:
False,
False,
False,
False,
),
(
"SP00013",
Expand All @@ -293,6 +367,7 @@ def fake_sample_qc() -> pd.DataFrame:
False,
False,
False,
False,
),
(
"SP00014",
Expand All @@ -305,24 +380,48 @@ def fake_sample_qc() -> pd.DataFrame:
False,
False,
False,
False,
),
(
"SP00015",
"SB00010",
"SP00015|SP00016",
False,
False,
0.99,
False,
True,
True,
False,
True,
),
(
"SP00016",
"SB00010",
"SP00015|SP00016",
False,
False,
0.99,
False,
False,
False,
False,
True,
),
("SP00015", "SB00010", "SP00015|SP00016", False, False, 0.99, False, True, False, True),
("SP00016", "SB00010", "SP00015|SP00016", False, False, 0.99, False, False, False, True),
]
return pd.DataFrame(data, columns=columns).set_index("Sample_ID")
return pd.DataFrame(data, columns=columns)


@pytest.mark.parametrize(
    "contam,rep_discordant,num_removed",
    [(False, False, 2), (True, False, 3), (False, True, 5), (True, True, 5)],
)
def test_add_analytic_exclusion(fake_sample_qc, contam, rep_discordant, num_removed):
    """Check how many samples are flagged ``analytic_exclusion``.

    Each case toggles whether contamination and replicate discordance are
    used as exclusion criteria and compares the number of flagged samples in
    the ``fake_sample_qc`` fixture with ``num_removed``.
    """
    # Assert on the returned table so the test does not depend on the
    # function mutating the fixture in place.
    result = sample_qc_table._add_analytic_exclusion(fake_sample_qc, contam, rep_discordant)
    assert num_removed == result.analytic_exclusion.sum()


# change these since I updated fake_sample_qc
@pytest.mark.parametrize(
"contam,rep_discordant,num_subjects",
[(False, False, 9), (True, False, 9), (False, True, 8), (True, True, 8)],
Expand Down
Loading