Squashed commit of the following:
commit 5715e99
Author: Marcel Mück <mueckm1@gmail.com>
Date:   Tue May 14 09:30:02 2024 +0200

    Fix omitted variants, correct test cases.  (#84)

    * added an assert to make sure no variants are omitted when merging annotations (see the sketch after this commit message)

    * changed expected files, sorted input files in aggregate_abscores

    * fixup! Format Python code with psf/black pull_request

    ---------

    Co-authored-by: “Marcel-Mueck” <“mueckm1@gmail.com”>
    Co-authored-by: PMBio <PMBio@users.noreply.github.com>
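
A minimal sketch of the kind of merge-completeness guard this commit describes, using made-up frames and column names rather than the exact assertion from the diff below:

    import pandas as pd

    # Made-up variant and annotation tables keyed by variant id.
    variants = pd.DataFrame({"id": [1, 2, 3], "chrom": ["1", "1", "2"]})
    annotations = pd.DataFrame({"id": [1, 2, 3], "score": [0.1, 0.2, 0.3]})

    # Keep every annotation row, then check that each row found its variant
    # and that no variant id was lost in the merge.
    merged = annotations.merge(variants, on="id", how="left")
    assert merged["chrom"].isna().sum() == 0, "annotations with unknown variants"
    assert merged["id"].nunique() == variants["id"].nunique(), "variants omitted during merge"
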

commit 6f04801
Author: Kayla Meyer <129152803+meyerkm@users.noreply.github.com>
Date:   Mon May 13 14:49:59 2024 +0200

    Revert "Feature streamline cv training" (#79)

    * Revert "Feature streamline cv training (#71)"

    This reverts commit d9ede28.

    * Add association-only flag in rule config. Add rule evaluate to complete the cv-training-association-testing pipeline.

    * n_avg_chunks = 1

    ---------

    Co-authored-by: Magnus Wahlberg <endast@gmail.com>
    Co-authored-by: Eva Holtkamp <eva.holtkamp@gmx.de>

commit 99ed88e
Author: Marcel Mück <mueckm1@gmail.com>
Date:   Mon May 13 13:14:16 2024 +0200

    Feature/annotation tests (#82)

    * Create test skeleton for annotations

    * added annotation script to deeprvat setup, updated docs to reflect that change, added first test for annotation pipeline.

    * added tests for annotation pipeline, variant file now parquet

    * added data for test

    * added data for tests

    * Test for merge_deepsea_pcas function

    * added test for absplice score aggregation

    * added robustness for mixed entry types in ID column of input VCF, created test case (see the sketch after this commit)

    * added further tests

    * added pyranges to environment

    * Update absplice.yaml

    * Update environment_spliceai_rocksdb.yaml

    ---------

    Co-authored-by: Magnus Wahlberg <endast@gmail.com>
    Co-authored-by: Mück <m991k@b260-pc003.inet.dkfz-heidelberg.de>
    Co-authored-by: PMBio <PMBio@users.noreply.github.com>
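
One behavior these tests cover is parsing position-style entries in the VCF ID column (the fallback added to process_vep in the annotations.py diff below). A hedged, self-contained sketch of that test pattern; the helper re-implements the parsing for illustration only and is not the repository's actual test code:

    import pandas as pd

    def split_uploaded_variation(ids: pd.Series) -> pd.DataFrame:
        # Mirrors the fallback added to process_vep below: normalize the
        # "#Uploaded_variation" string to colon-separated fields and split it.
        out = (
            ids.str.replace("_", ":", regex=False)
            .str.replace("/", ":", regex=False)
            .str.split(":", expand=True)
        )
        out.columns = ["chrom", "pos", "ref", "alt"]
        return out

    def test_position_style_ids_are_split_into_fields():
        ids = pd.Series(["chr1_12345_A/T", "chr2_500_G/C"])
        out = split_uploaded_variation(ids)
        assert list(out.columns) == ["chrom", "pos", "ref", "alt"]
        assert out.loc[0, "pos"] == "12345"
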
endast committed May 14, 2024
1 parent c9abdfc commit 8ad84ea
Showing 98 changed files with 1,956 additions and 342 deletions.
21 changes: 21 additions & 0 deletions .github/workflows/test-runner.yml
@@ -42,3 +42,24 @@ jobs:
- name: Run pytest preprocessing
run: pytest -v ${{ github.workspace }}/tests/preprocessing
shell: micromamba-shell {0}

DeepRVAT-Tests-Runner-Annotations:
runs-on: ubuntu-latest
steps:

- name: Check out repository code
uses: actions/checkout@v4
- uses: mamba-org/setup-micromamba@v1.8.0
with:
environment-name: deeprvat-annotation-gh-action
environment-file: ${{ github.workspace }}/deeprvat_annotations.yml
cache-environment: true
cache-downloads: true

- name: Install DeepRVAT
run: pip install -e ${{ github.workspace }}
shell: micromamba-shell {0}

- name: Run pytest annotations
run: pytest -v ${{ github.workspace }}/tests/annotations
shell: micromamba-shell {0}
83 changes: 30 additions & 53 deletions deeprvat/annotations/annotations.py
@@ -1,5 +1,7 @@
import logging
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
import pickle
import random
import sys
@@ -1080,7 +1082,7 @@ def aggregate_abscores(
delayed(process_chunk)(
i, abs_splice_res_dir, tissues_to_exclude, tissue_agg_function, ca_shortened
)
for i in tqdm(os.listdir(abs_splice_res_dir))
for i in tqdm(sorted(os.listdir(abs_splice_res_dir)))
)
all_absplice_scores = list(output_generator)

@@ -1360,45 +1362,6 @@ def merge_deepsea_pcas(
merged.to_parquet(out_file)


@cli.command()
@click.argument("in_variants", type=click.Path(exists=True))
@click.argument("out_variants", type=click.Path())
def process_annotations(in_variants: str, out_variants: str):
"""
Process variant annotations, filter for canonical variants, and aggregate consequences.
Parameters:
- in_variants (str): Path to the input variant annotation file in parquet format.
- out_variants (str): Path to save the processed variant annotation file in parquet format.
Returns:
None
Notes:
- The function reads the input variant annotation file.
- It filters for canonical variants where the 'CANONICAL' column is equal to 'YES'.
- The 'Gene' column is renamed to 'gene_id'.
- Consequences for different alleles are aggregated by combining the variant ID with the gene ID.
- The processed variant annotations are saved to the specified output file.
Example:
$ python annotations.py process_annotations input_variants.parquet output_variants.parquet
"""
variant_path = Path(in_variants)
variants = pd.read_parquet(variant_path)

logger.info("filtering for canonical variants")

variants = variants.loc[variants.CANONICAL == "YES"]
variants.rename(columns={"Gene": "gene_id"}, inplace=True)

logger.info("aggregating consequences for different alleles")

# combining variant id with gene id
variants["censequence_id"] = variants["id"].astype(str) + variants["gene_id"]
variants.to_parquet(out_variants, compression="zstd")


def process_chunk_addids(chunk: pd.DataFrame, variants: pd.DataFrame) -> pd.DataFrame:
"""
Process a chunk of data by adding identifiers from a variants dataframe.
@@ -1507,16 +1470,14 @@ def add_ids(annotation_file: str, variant_file: str, njobs: int, out_file: str):
@cli.command()
@click.argument("annotation_file", type=click.Path(exists=True))
@click.argument("variant_file", type=click.Path(exists=True))
@click.argument("njobs", type=int)
@click.argument("out_file", type=click.Path())
def add_ids_dask(annotation_file: str, variant_file: str, njobs: int, out_file: str):
def add_ids_dask(annotation_file: str, variant_file: str, out_file: str):
"""
Add identifiers from a variant file to an annotation file using Dask and save the result.
Parameters:
- annotation_file (str): Path to the input annotation file in Parquet format.
- variant_file (str): Path to the input variant file in Parquet format.
- njobs (int): Number of parallel jobs to process the data.
- out_file (str): Path to save the processed data in Parquet format.
Returns:
@@ -1532,7 +1493,7 @@ def add_ids_dask(annotation_file: str, variant_file: str, njobs: int, out_file:
$ python annotations.py add_ids_dask annotation_data.parquet variant_data.parquet 4 processed_data.parquet
"""
data = dd.read_parquet(annotation_file, blocksize=25e9)
all_variants = pd.read_table(variant_file)
all_variants = pd.read_parquet(variant_file)
data = data.rename(
columns={
"#CHROM": "chrom",
@@ -1705,7 +1666,7 @@ def merge_annotations(
logger.info("load variant_file")

logger.info(f"reading in {variant_file}")
variants = pd.read_csv(variant_file, sep="\t")
variants = pd.read_parquet(variant_file)

logger.info("merge vep to variants M:1")
ca = vep_df.merge(
@@ -1777,7 +1738,23 @@ def process_vep(
vcf_file, names=["chrom", "pos", "#Uploaded_variation", "ref", "alt"]
)
if "#Uploaded_variation" in vep_file.columns:
vep_file = vep_file.merge(vcf_df, on="#Uploaded_variation")
vep_file = vep_file.merge(vcf_df, on="#Uploaded_variation", how="left")
if vep_file.chrom.isna().sum() > 0:
vep_file.loc[vep_file.chrom.isna(), ["chrom", "pos", "ref", "alt"]] = (
vep_file[vep_file["chrom"].isna()]["#Uploaded_variation"]
.str.replace("_", ":")
.str.replace("/", ":")
.str.split(":", expand=True)
.values
)
assert vep_file.chrom.isna().sum() == 0
assert vep_file.pos.isna().sum() == 0
assert vep_file.ref.isna().sum() == 0
assert vep_file.alt.isna().sum() == 0
assert (
vep_file[["chrom", "pos", "ref", "alt"]].drop_duplicates().shape
== vcf_df[["chrom", "pos", "ref", "alt"]].drop_duplicates().shape
)

if "pos" in vep_file.columns:
vep_file["pos"] = vep_file["pos"].astype(int)
@@ -1979,7 +1956,7 @@ def get_af_from_gt(genotype_file: str, variants_filepath: str, out_file: str):
"""
import h5py

variants = pd.read_table(variants_filepath)
variants = pd.read_parquet(variants_filepath)
max_variant_id = variants["id"].max()

logger.info("Computing allele frequencies")
@@ -2042,19 +2019,19 @@ def calculate_maf(annotations_path: str, out_file: str):


@cli.command()
@click.argument("protein_id_file", type=click.Path(exists=True))
@click.argument("gene_id_file", type=click.Path(exists=True))
@click.argument("annotations_path", type=click.Path(exists=True))
@click.argument("out_file", type=click.Path())
def add_protein_ids(protein_id_file: str, annotations_path: str, out_file: str):
def add_gene_ids(gene_id_file: str, annotations_path: str, out_file: str):
"""
Add protein IDs to the annotations based on protein ID mapping file.
Add gene IDs to the annotations based on gene ID mapping file.
Parameters:
- protein_id_file (str): Path to the protein ID mapping file.
- gene_id_file (str): Path to the gene ID mapping file.
- annotations_path (str): Path to the annotations file.
- out_file (str): Path to the output file to save the annotations with protein IDs.
"""
genes = pd.read_parquet(protein_id_file)
genes = pd.read_parquet(gene_id_file)
genes[["gene_base", "feature"]] = genes["gene"].str.split(".", expand=True)
genes.drop(columns=["feature", "gene", "gene_name", "gene_type"], inplace=True)
genes.rename(columns={"id": "gene_id"}, inplace=True)
@@ -2069,7 +2046,7 @@ def add_protein_ids(protein_id_file: str, annotations_path: str, out_file: str):
@cli.command()
@click.argument("gtf_filepath", type=click.Path(exists=True))
@click.argument("out_file", type=click.Path())
def create_protein_id_file(gtf_filepath: str, out_file: str):
def create_gene_id_file(gtf_filepath: str, out_file: str):
"""
Create a protein ID mapping file from the GTF file.
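
Several commands in this file now read the variant file with pd.read_parquet rather than pd.read_csv/pd.read_table. A minimal sketch of the corresponding one-off conversion, assuming a hypothetical tab-separated variants.tsv:

    import pandas as pd

    # Hypothetical input path; the real variant file comes from the preprocessing pipeline.
    variants = pd.read_csv("variants.tsv", sep="\t")
    variants.to_parquet("variants.parquet", compression="zstd", index=False)
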
11 changes: 7 additions & 4 deletions deeprvat/cv_utils.py
@@ -59,9 +59,11 @@ def spread_config(
cv_path = f"{config_template['cv_path']}/{n_folds}_fold"
for module in data_modules:
config = copy.deepcopy(config_template)
sample_file = f"{cv_path}/samples_{split}{fold}.pkl"
logger.info(f"setting sample file {sample_file}")
config["sample_file"] = sample_file
data_slots = DATA_SLOT_DICT[module]
for data_slot in data_slots:
sample_file = f"{cv_path}/samples_{split}{fold}.pkl"
logger.info(f"setting sample file {sample_file}")
config[data_slot]["dataset_config"]["sample_file"] = sample_file

if (module == "deeprvat") | (module == "deeprvat_pretrained"):
logger.info("Writing baseline directories")
@@ -89,7 +91,8 @@ def generate_test_config(input_config, out_file, fold, n_folds):
split = "test"
sample_file = f"{cv_path}/samples_{split}{fold}.pkl"
logger.info(f"setting sample file {sample_file}")
config["sample_file"] = sample_file
for data_slot in DATA_SLOT_DICT["deeprvat"]:
config[data_slot]["dataset_config"]["sample_file"] = sample_file
with open(out_file, "w") as f:
yaml.dump(config, f)

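A hedged sketch of the per-data-slot wiring introduced above; the slot names and config layout are illustrative stand-ins, not DeepRVAT's actual DATA_SLOT_DICT contents:

    import copy

    # Illustrative stand-ins for DATA_SLOT_DICT and the config template.
    DATA_SLOT_DICT = {"deeprvat": ["data", "training_data"]}
    config_template = {
        "data": {"dataset_config": {}},
        "training_data": {"dataset_config": {}},
    }
    sample_file = "cv_split0/5_fold/samples_train0.pkl"  # hypothetical path

    config = copy.deepcopy(config_template)
    for data_slot in DATA_SLOT_DICT["deeprvat"]:
        # The sample file is now written into each data slot's dataset_config
        # instead of once at the top level of the config.
        config[data_slot]["dataset_config"]["sample_file"] = sample_file
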
22 changes: 8 additions & 14 deletions deeprvat/data/dense_gt.py
@@ -362,7 +362,7 @@ def setup_phenotypes(
# account for the fact that genotypes.h5 and phenotype_df can have different
# orders of their samples
self.index_map_geno, _ = get_matched_sample_indices(
samples_gt.astype(str), self.samples.astype(str)
samples_gt.astype(int), self.samples.astype(int)
)
# get_matched_sample_indices is a much, much faster implementation of the code below
# self.index_map_geno = [np.where(samples_gt.astype(int) == i) for i in self.samples.astype(int)]
@@ -614,20 +614,13 @@ def setup_variants(
"Annotation dataframe has inconsistent allele frequency values"
)
variants_with_af = safe_merge(
variants[["id"]].reset_index(drop=True), af_annotation, how="left"
variants[["id"]].reset_index(drop=True), af_annotation
)
assert np.all(
variants_with_af["id"].to_numpy() == variants["id"].to_numpy()
)
af_isna = variants_with_af[af_col].isna()
if af_isna.sum() > 0:
logger.warning(
f"Dropping {af_isna.sum()} variants missing from annotation dataframe"
)
mask = (
(~af_isna)
& (variants_with_af[af_col] >= af_threshold)
& (variants_with_af[af_col] <= 1 - af_threshold)
mask = (variants_with_af[af_col] >= af_threshold) & (
variants_with_af[af_col] <= 1 - af_threshold
)
mask = mask.to_numpy()
del variants_with_af
@@ -938,10 +931,11 @@ def get_metadata(self) -> Dict[str, Any]:
result = {
"variant_metadata": self.variants[
["id", "common_variant_mask", "rare_variant_mask", "matrix_index"]
],
"samples": self.samples,
]
}
if self.use_rare_variants:
if hasattr(self.rare_embedding, "get_metadata"):
result["rare_embedding_metadata"] = self.rare_embedding.get_metadata()
result.update(
{"rare_embedding_metadata": self.rare_embedding.get_metadata()}
)
return result
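
A hedged sketch of the simplified allele-frequency filter above, on made-up values; the column name "af" stands in for the configured AF column:

    import pandas as pd

    # Made-up allele frequencies for three variants.
    af_threshold = 0.001
    variants_with_af = pd.DataFrame({"id": [0, 1, 2], "af": [0.0005, 0.01, 0.9999]})

    # Keep only variants whose AF lies within [af_threshold, 1 - af_threshold].
    mask = (variants_with_af["af"] >= af_threshold) & (
        variants_with_af["af"] <= 1 - af_threshold
    )
    mask = mask.to_numpy()
    print(variants_with_af.loc[mask, "id"].tolist())  # [1]
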
60 changes: 22 additions & 38 deletions deeprvat/deeprvat/associate.py
@@ -7,7 +7,7 @@
import sys
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple

import click
import dask.dataframe as dd
@@ -115,7 +115,9 @@ def cli():

def make_dataset_(
config: Dict,
debug: bool = False,
data_key="data",
samples: Optional[List[int]] = None,
) -> Dataset:
"""
Create a dataset based on the configuration.
@@ -147,17 +149,29 @@ def make_dataset_(
**copy.deepcopy(data_config["dataset_config"]),
)

restrict_samples = config.get("restrict_samples", None)
if debug:
logger.info("Debug flag set; Using only 1000 samples")
ds = Subset(ds, range(1_000))
elif samples is not None:
ds = Subset(ds, samples)
elif restrict_samples is not None:
ds = Subset(ds, range(restrict_samples))

return ds


@cli.command()
@click.option("--debug", is_flag=True)
@click.option("--data-key", type=str, default="data")
@click.argument("config-file", type=click.Path(exists=True, path_type=Path))
@click.argument("out-file", type=click.Path(path_type=Path))
def make_dataset(data_key: str, config_file: Path, out_file: Path):
@click.argument("config-file", type=click.Path(exists=True))
@click.argument("out-file", type=click.Path())
def make_dataset(debug: bool, data_key: str, config_file: str, out_file: str):
"""
Create a dataset based on the provided configuration and save to a pickle file.
:param debug: Flag for debugging.
:type debug: bool
:param data_key: Key for dataset configuration in the config dictionary, defaults to "data".
:type data_key: str
:param config_file: Path to the configuration file.
@@ -169,7 +183,7 @@ def make_dataset(data_key: str, config_file: Path, out_file: Path):
with open(config_file) as f:
config = yaml.safe_load(f)

ds = make_dataset_(config, data_key=data_key)
ds = make_dataset_(config, debug=debug, data_key=data_key)

with open(out_file, "wb") as f:
pickle.dump(ds, f)
@@ -222,8 +236,6 @@ def compute_burdens_(
.. note::
Checkpoint models all corresponding to the same repeat are averaged for that repeat.
"""
logger.setLevel(logging.INFO)

if not skip_burdens:
logger.info("agg_models[*][*].reverse:")
pprint(
@@ -235,38 +247,10 @@ def compute_burdens_(

data_config = config["data"]

if "sample_file" in config:
sample_file = Path(config["sample_file"])
logger.info(f"Using samples from {sample_file}")
if sample_file.suffix == ".pkl":
with open(sample_file, "rb") as f:
sample_ids = np.array(pickle.load(f))
elif sample_file.suffix == ".zarr":
sample_ids = zarr.load(sample_file)
elif sample_file.suffix == ".npy":
sample_ids = np.load(sample_file)
else:
raise ValueError("Unknown file type for sample_file")
ds_samples = ds.get_metadata()["samples"]
sample_indices = np.where(
np.isin(ds_samples.astype(str), sample_ids.astype(str))
)[0]
if debug:
sample_indices = sample_indices[:1000]
elif debug:
sample_indices = np.arange(min(1000, len(ds)))
else:
sample_indices = np.arange(len(ds))

logger.info(
f"Computing gene impairment for {sample_indices.shape[0]} samples: {sample_indices}"
)
ds = Subset(ds, sample_indices)

ds_full = ds.dataset # if isinstance(ds, Subset) else ds
ds_full = ds.dataset if isinstance(ds, Subset) else ds
collate_fn = getattr(ds_full, "collate_fn", None)
n_total_samples = len(ds)
ds_full.rare_embedding.skip_embedding = skip_burdens
ds.rare_embedding.skip_embedding = skip_burdens

if chunk is not None:
if n_chunks is None:
@@ -919,7 +903,7 @@ def compute_burdens(
with open(dataset_file, "rb") as f:
dataset = pickle.load(f)
else:
dataset = make_dataset_(data_config)
dataset = make_dataset_(config)

if torch.cuda.is_available():
logger.info("Using GPU")
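
A hedged sketch of the sample-subsetting logic added to make_dataset_ above, demonstrated on a toy TensorDataset rather than the DeepRVAT dataset:

    import torch
    from torch.utils.data import Subset, TensorDataset

    # Toy dataset standing in for the DeepRVAT dataset object.
    ds = TensorDataset(torch.arange(10_000))

    debug = False
    samples = None              # optional explicit sample indices
    restrict_samples = 2_000    # corresponds to config.get("restrict_samples")

    if debug:
        # Debug runs only use the first 1000 samples.
        ds = Subset(ds, range(1_000))
    elif samples is not None:
        ds = Subset(ds, samples)
    elif restrict_samples is not None:
        ds = Subset(ds, range(restrict_samples))

    print(len(ds))  # 2000
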
(Diffs for the remaining 93 changed files are not shown.)