Skip to content

Commit

Permalink
adjustments for RAP
Browse files Browse the repository at this point in the history
  • Loading branch information
bfclarke committed Oct 2, 2023
1 parent 980abf0 commit dc6232e
Show file tree
Hide file tree
Showing 16 changed files with 824 additions and 279 deletions.
74 changes: 48 additions & 26 deletions deeprvat/data/dense_gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ def __init__(

if grouping_level is not None:
if grouping_level == "gene":
self.grouping_column = "gene_ids"
self.grouping_column = "gene_id"
elif grouping_level == "exon":
self.grouping_column = "exon_ids"
self.grouping_column = "exon_id"
else:
raise ValueError(f"Unknown aggregation level {grouping_level}")
else:
Expand Down Expand Up @@ -491,32 +491,49 @@ def setup_variants(
" or min_common_af must be specified"
)

logger.debug(" Reading variant dataframe")
logger.debug(f" Reading variant dataframe {self.variant_filename}")
variants = dd.read_parquet(self.variant_filename, engine="pyarrow").compute()
variants = variants.set_index("id", drop=False)
variants = variants.drop(columns="matrix_index", errors="ignore")

logger.debug(" Subsetting variants")
if self.variants_to_keep is not None:
logger.info("Selecting subset of variants as defined by variants_to_keep")
variants = variants.loc[self.variants_to_keep]
logger.debug(" Filtering variants")
if min_common_variant_count is not None:
logger.debug(" Selecting common variants by MAC")
mask = (variants["count"] >= min_common_variant_count) & (
variants["count"] <= self.n_samples - min_common_variant_count
)
mask = mask.to_numpy()
logger.debug(f' {mask.sum()} variants "common" by count filter')
elif min_common_af is not None:
logger.debug(" Selecting common variants by MAF")
af_col, af_threshold = list(min_common_af.items())[0]
variants_with_af = safe_merge(
af_df = self.annotation_df.reset_index()[["id", af_col]].drop_duplicates()
try:
assert af_df["id"].unique().shape[0] == len(af_df)
except:
raise ValueError(
"Inconsistent allele frequencies in annotation dataframe"
)
# variants_with_af = safe_merge(
# variants[["id"]].reset_index(drop=True),
# af_df,
# )
variants_with_af = pd.merge(
variants[["id"]].reset_index(drop=True),
self.annotation_df[[af_col]].reset_index(),
af_df,
how="left",
validate="1:1",
)
assert np.all(
variants_with_af["id"].to_numpy() == variants["id"].to_numpy()
)
mask = (variants_with_af[af_col] >= af_threshold) & (
variants_with_af[af_col] <= 1 - af_threshold
mask = (
~variants_with_af[af_col].isna()
& (variants_with_af[af_col] >= af_threshold)
& (variants_with_af[af_col] <= 1 - af_threshold)
)
mask = mask.to_numpy()
del variants_with_af
Expand All @@ -536,6 +553,8 @@ def setup_variants(
rare_variant_mask = ~mask
chromosome_mask = variants["chrom"].isin(self.chromosomes).to_numpy()
additional_mask = chromosome_mask
for c in exclude_variant_cols:
additional_mask &= ~variants[c].to_numpy()
if self.exons_to_keep is not None:
raise NotImplementedError("The variant dataframes have outdated exon_ids")
additional_mask &= (
Expand All @@ -551,28 +570,34 @@ def setup_variants(
.to_numpy()
)
if self.gene_file is not None:
logger.debug(" Selecting variants by gene")
genes = set(pd.read_parquet(self.gene_file, columns=["id"])["id"])
logger.debug(f" Retaining {len(genes)} genes from {self.gene_file}")
variants_with_gene_ids = safe_merge(
variants[["id"]].reset_index(drop=True),
self.annotation_df[["gene_ids"]].reset_index(),
)
assert np.all(
variants_with_gene_ids["id"].to_numpy() == variants["id"].to_numpy()
)
additional_mask &= (
variants_with_gene_ids["gene_ids"]
.apply(lambda x: len(set(x) & genes) != 0)
# variants_with_gene_ids = safe_merge(
# variants[["id"]].reset_index(drop=True),
# self.annotation_df[["gene_id"]].reset_index(),
# validate = "1:m"
# )
# assert np.all(
# variants_with_gene_ids["id"].to_numpy() == variants["id"].to_numpy()
# )
# additional_mask &= variants_with_gene_ids["gene_id"].isin(genes).to_numpy()
# del variants_with_gene_ids
ids_to_keep = (
self.annotation_df.reset_index()
.query("gene_id in @genes")["id"]
.to_numpy()
)
del variants_with_gene_ids
additional_mask &= variants["id"].isin(ids_to_keep)
if self.gene_types_to_keep is not None:
raise NotImplementedError
additional_mask &= (
variants["gene_types"]
.apply(lambda x: len(set(x) & self.gene_types_to_keep) != 0)
.to_numpy()
)
if self.ignore_by_annotation is not None:
raise NotImplementedError
for col, val in self.ignore_by_annotation:
if self.annotation_df[col].dtype == np.dtype("object"):
additional_mask &= (
Expand All @@ -592,9 +617,8 @@ def setup_variants(
and self.gene_file is None
and self.gene_types_to_keep is None
):
rare_variant_mask &= (
variants["gene_ids"].apply(lambda x: len(x) > 0).to_numpy()
)
logger.debug(" Selecting variants by gene type")
rare_variant_mask &= variants["gene_id"].notna().to_numpy()

variants["rare_variant_mask"] = rare_variant_mask

Expand All @@ -604,9 +628,7 @@ def setup_variants(
common_variant_mask &= ~af_mask
common_variant_mask &= additional_mask
if self.group_common:
common_variant_mask &= (
variants["gene_ids"].apply(lambda x: len(x) > 0).to_numpy()
)
common_variant_mask &= variants["gene_id"].notna().to_numpy()

variants["matrix_index"] = -1
matrix_index_mask = common_variant_mask
Expand Down Expand Up @@ -665,7 +687,7 @@ def setup_common_groups(self):

common_variant_groups = common_variant_groups.explode(self.grouping_column)
common_variant_groups = common_variant_groups[
common_variant_groups["gene_ids"].notna()
common_variant_groups["gene_id"].notna()
]

if self.return_sparse:
Expand Down
Loading

0 comments on commit dc6232e

Please sign in to comment.