Skip to content

Commit

Permalink
Revert "Squashed commit of the following:"
Browse files Browse the repository at this point in the history
This reverts commit 8ad84ea.
  • Loading branch information
endast committed May 14, 2024
1 parent 8ad84ea commit 588b414
Show file tree
Hide file tree
Showing 98 changed files with 342 additions and 1,956 deletions.
21 changes: 0 additions & 21 deletions .github/workflows/test-runner.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,3 @@ jobs:
- name: Run pytest preprocessing
run: pytest -v ${{ github.workspace }}/tests/preprocessing
shell: micromamba-shell {0}

DeepRVAT-Tests-Runner-Annotations:
runs-on: ubuntu-latest
steps:

- name: Check out repository code
uses: actions/checkout@v4
- uses: mamba-org/setup-micromamba@v1.8.0
with:
environment-name: deeprvat-annotation-gh-action
environment-file: ${{ github.workspace }}/deeprvat_annotations.yml
cache-environment: true
cache-downloads: true

- name: Install DeepRVAT
run: pip install -e ${{ github.workspace }}
shell: micromamba-shell {0}

- name: Run pytest annotations
run: pytest -v ${{ github.workspace }}/tests/annotations
shell: micromamba-shell {0}
83 changes: 53 additions & 30 deletions deeprvat/annotations/annotations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import logging
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
import pickle
import random
import sys
Expand Down Expand Up @@ -1082,7 +1080,7 @@ def aggregate_abscores(
delayed(process_chunk)(
i, abs_splice_res_dir, tissues_to_exclude, tissue_agg_function, ca_shortened
)
for i in tqdm(sorted(os.listdir(abs_splice_res_dir)))
for i in tqdm(os.listdir(abs_splice_res_dir))
)
all_absplice_scores = list(output_generator)

Expand Down Expand Up @@ -1362,6 +1360,45 @@ def merge_deepsea_pcas(
merged.to_parquet(out_file)


@cli.command()
@click.argument("in_variants", type=click.Path(exists=True))
@click.argument("out_variants", type=click.Path())
def process_annotations(in_variants: str, out_variants: str):
"""
Process variant annotations, filter for canonical variants, and aggregate consequences.
Parameters:
- in_variants (str): Path to the input variant annotation file in parquet format.
- out_variants (str): Path to save the processed variant annotation file in parquet format.
Returns:
None
Notes:
- The function reads the input variant annotation file.
- It filters for canonical variants where the 'CANONICAL' column is equal to 'YES'.
- The 'Gene' column is renamed to 'gene_id'.
- Consequences for different alleles are aggregated by combining the variant ID with the gene ID.
- The processed variant annotations are saved to the specified output file.
Example:
$ python annotations.py process_annotations input_variants.parquet output_variants.parquet
"""
variant_path = Path(in_variants)
variants = pd.read_parquet(variant_path)

logger.info("filtering for canonical variants")

variants = variants.loc[variants.CANONICAL == "YES"]
variants.rename(columns={"Gene": "gene_id"}, inplace=True)

logger.info("aggregating consequences for different alleles")

# combining variant id with gene id
variants["censequence_id"] = variants["id"].astype(str) + variants["gene_id"]
variants.to_parquet(out_variants, compression="zstd")


def process_chunk_addids(chunk: pd.DataFrame, variants: pd.DataFrame) -> pd.DataFrame:
"""
Process a chunk of data by adding identifiers from a variants dataframe.
Expand Down Expand Up @@ -1470,14 +1507,16 @@ def add_ids(annotation_file: str, variant_file: str, njobs: int, out_file: str):
@cli.command()
@click.argument("annotation_file", type=click.Path(exists=True))
@click.argument("variant_file", type=click.Path(exists=True))
@click.argument("njobs", type=int)
@click.argument("out_file", type=click.Path())
def add_ids_dask(annotation_file: str, variant_file: str, out_file: str):
def add_ids_dask(annotation_file: str, variant_file: str, njobs: int, out_file: str):
"""
Add identifiers from a variant file to an annotation file using Dask and save the result.
Parameters:
- annotation_file (str): Path to the input annotation file in Parquet format.
- variant_file (str): Path to the input variant file in Parquet format.
- njobs (int): Number of parallel jobs to process the data.
- out_file (str): Path to save the processed data in Parquet format.
Returns:
Expand All @@ -1493,7 +1532,7 @@ def add_ids_dask(annotation_file: str, variant_file: str, out_file: str):
$ python annotations.py add_ids_dask annotation_data.parquet variant_data.parquet 4 processed_data.parquet
"""
data = dd.read_parquet(annotation_file, blocksize=25e9)
all_variants = pd.read_parquet(variant_file)
all_variants = pd.read_table(variant_file)
data = data.rename(
columns={
"#CHROM": "chrom",
Expand Down Expand Up @@ -1666,7 +1705,7 @@ def merge_annotations(
logger.info("load variant_file")

logger.info(f"reading in {variant_file}")
variants = pd.read_parquet(variant_file)
variants = pd.read_csv(variant_file, sep="\t")

logger.info("merge vep to variants M:1")
ca = vep_df.merge(
Expand Down Expand Up @@ -1738,23 +1777,7 @@ def process_vep(
vcf_file, names=["chrom", "pos", "#Uploaded_variation", "ref", "alt"]
)
if "#Uploaded_variation" in vep_file.columns:
vep_file = vep_file.merge(vcf_df, on="#Uploaded_variation", how="left")
if vep_file.chrom.isna().sum() > 0:
vep_file.loc[vep_file.chrom.isna(), ["chrom", "pos", "ref", "alt"]] = (
vep_file[vep_file["chrom"].isna()]["#Uploaded_variation"]
.str.replace("_", ":")
.str.replace("/", ":")
.str.split(":", expand=True)
.values
)
assert vep_file.chrom.isna().sum() == 0
assert vep_file.pos.isna().sum() == 0
assert vep_file.ref.isna().sum() == 0
assert vep_file.alt.isna().sum() == 0
assert (
vep_file[["chrom", "pos", "ref", "alt"]].drop_duplicates().shape
== vcf_df[["chrom", "pos", "ref", "alt"]].drop_duplicates().shape
)
vep_file = vep_file.merge(vcf_df, on="#Uploaded_variation")

if "pos" in vep_file.columns:
vep_file["pos"] = vep_file["pos"].astype(int)
Expand Down Expand Up @@ -1956,7 +1979,7 @@ def get_af_from_gt(genotype_file: str, variants_filepath: str, out_file: str):
"""
import h5py

variants = pd.read_parquet(variants_filepath)
variants = pd.read_table(variants_filepath)
max_variant_id = variants["id"].max()

logger.info("Computing allele frequencies")
Expand Down Expand Up @@ -2019,19 +2042,19 @@ def calculate_maf(annotations_path: str, out_file: str):


@cli.command()
@click.argument("gene_id_file", type=click.Path(exists=True))
@click.argument("protein_id_file", type=click.Path(exists=True))
@click.argument("annotations_path", type=click.Path(exists=True))
@click.argument("out_file", type=click.Path())
def add_gene_ids(gene_id_file: str, annotations_path: str, out_file: str):
def add_protein_ids(protein_id_file: str, annotations_path: str, out_file: str):
"""
Add gene IDs to the annotations based on gene ID mapping file.
Add protein IDs to the annotations based on protein ID mapping file.
Parameters:
- gene_id_file (str): Path to the gene ID mapping file.
- protein_id_file (str): Path to the protein ID mapping file.
- annotations_path (str): Path to the annotations file.
- out_file (str): Path to the output file to save the annotations with protein IDs.
"""
genes = pd.read_parquet(gene_id_file)
genes = pd.read_parquet(protein_id_file)
genes[["gene_base", "feature"]] = genes["gene"].str.split(".", expand=True)
genes.drop(columns=["feature", "gene", "gene_name", "gene_type"], inplace=True)
genes.rename(columns={"id": "gene_id"}, inplace=True)
Expand All @@ -2046,7 +2069,7 @@ def add_gene_ids(gene_id_file: str, annotations_path: str, out_file: str):
@cli.command()
@click.argument("gtf_filepath", type=click.Path(exists=True))
@click.argument("out_file", type=click.Path())
def create_gene_id_file(gtf_filepath: str, out_file: str):
def create_protein_id_file(gtf_filepath: str, out_file: str):
"""
Create a protein ID mapping file from the GTF file.
Expand Down
11 changes: 4 additions & 7 deletions deeprvat/cv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,9 @@ def spread_config(
cv_path = f"{config_template['cv_path']}/{n_folds}_fold"
for module in data_modules:
config = copy.deepcopy(config_template)
data_slots = DATA_SLOT_DICT[module]
for data_slot in data_slots:
sample_file = f"{cv_path}/samples_{split}{fold}.pkl"
logger.info(f"setting sample file {sample_file}")
config[data_slot]["dataset_config"]["sample_file"] = sample_file
sample_file = f"{cv_path}/samples_{split}{fold}.pkl"
logger.info(f"setting sample file {sample_file}")
config["sample_file"] = sample_file

if (module == "deeprvat") | (module == "deeprvat_pretrained"):
logger.info("Writing baseline directories")
Expand Down Expand Up @@ -91,8 +89,7 @@ def generate_test_config(input_config, out_file, fold, n_folds):
split = "test"
sample_file = f"{cv_path}/samples_{split}{fold}.pkl"
logger.info(f"setting sample file {sample_file}")
for data_slot in DATA_SLOT_DICT["deeprvat"]:
config[data_slot]["dataset_config"]["sample_file"] = sample_file
config["sample_file"] = sample_file
with open(out_file, "w") as f:
yaml.dump(config, f)

Expand Down
22 changes: 14 additions & 8 deletions deeprvat/data/dense_gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def setup_phenotypes(
# account for the fact that genotypes.h5 and phenotype_df can have different
# orders of their samples
self.index_map_geno, _ = get_matched_sample_indices(
samples_gt.astype(int), self.samples.astype(int)
samples_gt.astype(str), self.samples.astype(str)
)
# get_matched_sample_indices is a much, much faster implementation of the code below
# self.index_map_geno = [np.where(samples_gt.astype(int) == i) for i in self.samples.astype(int)]
Expand Down Expand Up @@ -614,13 +614,20 @@ def setup_variants(
"Annotation dataframe has inconsistent allele frequency values"
)
variants_with_af = safe_merge(
variants[["id"]].reset_index(drop=True), af_annotation
variants[["id"]].reset_index(drop=True), af_annotation, how="left"
)
assert np.all(
variants_with_af["id"].to_numpy() == variants["id"].to_numpy()
)
mask = (variants_with_af[af_col] >= af_threshold) & (
variants_with_af[af_col] <= 1 - af_threshold
af_isna = variants_with_af[af_col].isna()
if af_isna.sum() > 0:
logger.warning(
f"Dropping {af_isna.sum()} variants missing from annotation dataframe"
)
mask = (
(~af_isna)
& (variants_with_af[af_col] >= af_threshold)
& (variants_with_af[af_col] <= 1 - af_threshold)
)
mask = mask.to_numpy()
del variants_with_af
Expand Down Expand Up @@ -931,11 +938,10 @@ def get_metadata(self) -> Dict[str, Any]:
result = {
"variant_metadata": self.variants[
["id", "common_variant_mask", "rare_variant_mask", "matrix_index"]
]
],
"samples": self.samples,
}
if self.use_rare_variants:
if hasattr(self.rare_embedding, "get_metadata"):
result.update(
{"rare_embedding_metadata": self.rare_embedding.get_metadata()}
)
result["rare_embedding_metadata"] = self.rare_embedding.get_metadata()
return result
60 changes: 38 additions & 22 deletions deeprvat/deeprvat/associate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import click
import dask.dataframe as dd
Expand Down Expand Up @@ -115,9 +115,7 @@ def cli():

def make_dataset_(
config: Dict,
debug: bool = False,
data_key="data",
samples: Optional[List[int]] = None,
) -> Dataset:
"""
Create a dataset based on the configuration.
Expand Down Expand Up @@ -149,29 +147,17 @@ def make_dataset_(
**copy.deepcopy(data_config["dataset_config"]),
)

restrict_samples = config.get("restrict_samples", None)
if debug:
logger.info("Debug flag set; Using only 1000 samples")
ds = Subset(ds, range(1_000))
elif samples is not None:
ds = Subset(ds, samples)
elif restrict_samples is not None:
ds = Subset(ds, range(restrict_samples))

return ds


@cli.command()
@click.option("--debug", is_flag=True)
@click.option("--data-key", type=str, default="data")
@click.argument("config-file", type=click.Path(exists=True))
@click.argument("out-file", type=click.Path())
def make_dataset(debug: bool, data_key: str, config_file: str, out_file: str):
@click.argument("config-file", type=click.Path(exists=True, path_type=Path))
@click.argument("out-file", type=click.Path(path_type=Path))
def make_dataset(data_key: str, config_file: Path, out_file: Path):
"""
Create a dataset based on the provided configuration and save to a pickle file.
:param debug: Flag for debugging.
:type debug: bool
:param data_key: Key for dataset configuration in the config dictionary, defaults to "data".
:type data_key: str
:param config_file: Path to the configuration file.
Expand All @@ -183,7 +169,7 @@ def make_dataset(debug: bool, data_key: str, config_file: str, out_file: str):
with open(config_file) as f:
config = yaml.safe_load(f)

ds = make_dataset_(config, debug=debug, data_key=data_key)
ds = make_dataset_(config, data_key=data_key)

with open(out_file, "wb") as f:
pickle.dump(ds, f)
Expand Down Expand Up @@ -236,6 +222,8 @@ def compute_burdens_(
.. note::
Checkpoint models all corresponding to the same repeat are averaged for that repeat.
"""
logger.setLevel(logging.INFO)

if not skip_burdens:
logger.info("agg_models[*][*].reverse:")
pprint(
Expand All @@ -247,10 +235,38 @@ def compute_burdens_(

data_config = config["data"]

ds_full = ds.dataset if isinstance(ds, Subset) else ds
if "sample_file" in config:
sample_file = Path(config["sample_file"])
logger.info(f"Using samples from {sample_file}")
if sample_file.suffix == ".pkl":
with open(sample_file, "rb") as f:
sample_ids = np.array(pickle.load(f))
elif sample_file.suffix == ".zarr":
sample_ids = zarr.load(sample_file)
elif sample_file.suffix == ".npy":
sample_ids = np.load(sample_file)
else:
raise ValueError("Unknown file type for sample_file")
ds_samples = ds.get_metadata()["samples"]
sample_indices = np.where(
np.isin(ds_samples.astype(str), sample_ids.astype(str))
)[0]
if debug:
sample_indices = sample_indices[:1000]
elif debug:
sample_indices = np.arange(min(1000, len(ds)))
else:
sample_indices = np.arange(len(ds))

logger.info(
f"Computing gene impairment for {sample_indices.shape[0]} samples: {sample_indices}"
)
ds = Subset(ds, sample_indices)

ds_full = ds.dataset # if isinstance(ds, Subset) else ds
collate_fn = getattr(ds_full, "collate_fn", None)
n_total_samples = len(ds)
ds.rare_embedding.skip_embedding = skip_burdens
ds_full.rare_embedding.skip_embedding = skip_burdens

if chunk is not None:
if n_chunks is None:
Expand Down Expand Up @@ -903,7 +919,7 @@ def compute_burdens(
with open(dataset_file, "rb") as f:
dataset = pickle.load(f)
else:
dataset = make_dataset_(config)
dataset = make_dataset_(data_config)

if torch.cuda.is_available():
logger.info("Using GPU")
Expand Down
Loading

0 comments on commit 588b414

Please sign in to comment.