Skip to content

Commit

Permalink
pipeline and script to create deeprvat_config.yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
meyerkm committed May 22, 2024
1 parent 4560460 commit c9d027e
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 1 deletion.
150 changes: 149 additions & 1 deletion deeprvat/deeprvat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
import yaml

from deeprvat.deeprvat.evaluate import pval_correction
import deeprvat.deeprvat as deeprvat_dir
import pretrained_models as pretrained_dir
import os
from copy import deepcopy

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
Expand All @@ -23,6 +27,150 @@
def cli():
pass

@cli.command()
@click.argument("config-file", type=click.Path(exists=True))
@click.argument("output-dir", type=click.Path(), default='.')
def create_main_config(
    config_file: str,
    output_dir: str,
):
    """
    Generates the necessary deeprvat_config.yaml file for running all pipelines.
    This function expects inputs as shown in the following config-file:
        - DEEPRVAT_DIR/example/deeprvat_input_config.yaml

    :param config_file: Path to the relevant input config yaml file
    :type config_file: str
    :param output_dir: Path to directory where created deeprvat_config.yaml will be saved.
    :type output_dir: str
    :raises KeyError: if the input yaml contains unexpected keys, or is missing
        required keys.
    :return: Joined configuration file saved to deeprvat_config.yaml.
    """
    # Base Config -- packaged defaults that the user's inputs are merged on top of
    with open(
        f"{os.path.dirname(deeprvat_dir.__file__)}/base_configurations.yaml"
    ) as f:
        base_config = yaml.safe_load(f)

    expected_input_keys = [
        "phenotypes_for_association_testing",
        "phenotypes_for_training",
        "gt_filename",
        "variant_filename",
        "phenotype_filename",
        "annotation_filename",
        "gene_filename",
        "rare_variant_annotations",
        "covariates",
        "association_testing_data_thresholds",
        "training_data_thresholds",
        "seed_gene_results",
        "pl_trainer",
        "early_stopping",
        "n_repeats",
        "y_transformation",
        "cv_exp",
        "cv_path",
        "n_folds",
    ]

    full_config = base_config

    with open(config_file) as f:
        input_config = yaml.safe_load(f)

    # CV setup parameters: CV-only keys are disallowed for non-CV experiments
    if not input_config["cv_exp"]:
        logger.info("Not CV setup...removing CV pipeline parameters from config")
        to_remove = {"cv_path", "n_folds"}
        expected_input_keys = [
            item for item in expected_input_keys if item not in to_remove
        ]
        full_config["cv_exp"] = False
    else:
        full_config["cv_path"] = input_config["cv_path"]
        full_config["n_folds"] = input_config["n_folds"]
        full_config["cv_exp"] = True

    # Pretrained-model setup: training-only keys are dropped, and the packaged
    # pretrained-model config (annotations, thresholds, model spec) overrides
    # the corresponding user inputs.
    no_pretrain = True
    if "use_pretrained_models" in input_config:
        # Accept the key even when set to False, so an explicit
        # "use_pretrained_models: False" does not trip the extra-key check below.
        expected_input_keys.append("use_pretrained_models")
        if input_config["use_pretrained_models"]:
            no_pretrain = False
            logger.info("Pretrained Model setup specified.")
            to_remove = {"pl_trainer", "early_stopping"}
            expected_input_keys = [
                item for item in expected_input_keys if item not in to_remove
            ]
            expected_input_keys.append("model")

            with open(
                f"{os.path.dirname(pretrained_dir.__file__)}/config.yaml"
            ) as f:
                pretrained_config = yaml.safe_load(f)

            for k in pretrained_config:
                input_config[k] = deepcopy(pretrained_config[k])

    # Validate input keys up front so problems surface as one clear error
    # instead of a raw KeyError deep in the merge below.
    extra_keys = set(input_config.keys()) - set(expected_input_keys)
    if extra_keys:
        raise KeyError(("Unspecified key present in input YAML file. "
                        f"The follow extra keys are present: {extra_keys} "
                        "Please review DEEPRVAT_DIR/example/config/deeprvat_input_config.yaml for list of keys."))
    missing_keys = set(expected_input_keys) - set(input_config.keys())
    if missing_keys:
        raise KeyError(("Missing key(s) in input YAML file: "
                        f"{missing_keys} "
                        "Please review DEEPRVAT_DIR/example/config/deeprvat_input_config.yaml for list of keys."))

    # NOTE: "assocation_testing_data" (sic) matches the key spelling used in
    # base_configurations.yaml; it is renamed to "data" further below.
    # Phenotypes
    full_config["phenotypes"] = input_config["phenotypes_for_association_testing"]
    full_config["training"]["phenotypes"] = input_config["phenotypes_for_training"]
    full_config["training_data"]["dataset_config"]["y_transformation"] = input_config["y_transformation"]
    full_config["assocation_testing_data"]["dataset_config"]["y_transformation"] = input_config["y_transformation"]
    # genotypes.h5
    full_config["training_data"]["gt_file"] = input_config["gt_filename"]
    full_config["assocation_testing_data"]["gt_file"] = input_config["gt_filename"]
    # variants.parquet
    full_config["training_data"]["variant_file"] = input_config["variant_filename"]
    full_config["assocation_testing_data"]["variant_file"] = input_config["variant_filename"]
    # phenotypes.parquet
    full_config["training_data"]["dataset_config"]["phenotype_file"] = input_config["phenotype_filename"]
    full_config["assocation_testing_data"]["dataset_config"]["phenotype_file"] = input_config["phenotype_filename"]
    # annotations.parquet
    full_config["training_data"]["dataset_config"]["annotation_file"] = input_config["annotation_filename"]
    full_config["assocation_testing_data"]["dataset_config"]["annotation_file"] = input_config["annotation_filename"]
    # protein_coding_genes.parquet
    full_config["assocation_testing_data"]["dataset_config"]["gene_file"] = input_config["gene_filename"]
    full_config["assocation_testing_data"]["dataset_config"]["rare_embedding"]["config"]["gene_file"] = input_config["gene_filename"]
    # rare_variant_annotations
    full_config["training_data"]["dataset_config"]["rare_embedding"]["config"]["annotations"] = input_config["rare_variant_annotations"]
    full_config["assocation_testing_data"]["dataset_config"]["rare_embedding"]["config"]["annotations"] = input_config["rare_variant_annotations"]
    # variant annotations: the dataset annotation columns are the rare-variant
    # annotations with the training threshold columns inserted after the first entry
    anno_list = deepcopy(input_config["rare_variant_annotations"])
    for i, k in enumerate(input_config["training_data_thresholds"].keys()):
        anno_list.insert(i + 1, k)
    full_config["training_data"]["dataset_config"]["annotations"] = anno_list
    full_config["assocation_testing_data"]["dataset_config"]["annotations"] = anno_list
    # covariates
    full_config["training_data"]["dataset_config"]["x_phenotypes"] = input_config["covariates"]
    full_config["assocation_testing_data"]["dataset_config"]["x_phenotypes"] = input_config["covariates"]
    # Thresholds: stored as "<column> <operator> <value>" filter strings
    full_config["training_data"]["dataset_config"]["rare_embedding"]["config"]["thresholds"] = {}
    full_config["assocation_testing_data"]["dataset_config"]["rare_embedding"]["config"]["thresholds"] = {}
    for k, v in input_config["training_data_thresholds"].items():
        full_config["training_data"]["dataset_config"]["rare_embedding"]["config"]["thresholds"][k] = f"{k} {v}"
    for k, v in input_config["association_testing_data_thresholds"].items():
        full_config["assocation_testing_data"]["dataset_config"]["rare_embedding"]["config"]["thresholds"][k] = f"{k} {v}"
    # Baseline results
    full_config["baseline_results"]["options"] = input_config["seed_gene_results"]["options"]
    full_config["alpha"] = input_config["seed_gene_results"]["alpha"]
    # DeepRVAT model
    full_config["n_repeats"] = input_config["n_repeats"]

    # Downstream pipelines read the association-testing dataset under "data"
    full_config["data"] = full_config["assocation_testing_data"]
    del full_config["assocation_testing_data"]

    if no_pretrain:
        # PL trainer
        full_config["pl_trainer"] = input_config["pl_trainer"]
        # Early Stopping
        full_config["early_stopping"] = input_config["early_stopping"]
    else:
        full_config["model"] = input_config["model"]

        # NOTE(review): also writes the generated config next to the packaged
        # pretrained models -- presumably so the pretrained-model pipeline can
        # find it there; confirm this side effect is still required.
        with open(
            f"{os.path.dirname(pretrained_dir.__file__)}/deeprvat_config.yaml", "w"
        ) as f:
            yaml.dump(full_config, f)

    # Ensure the requested output directory exists before writing the result
    os.makedirs(output_dir, exist_ok=True)
    with open(f"{output_dir}/deeprvat_config.yaml", "w") as f:
        yaml.dump(full_config, f)

@cli.command()
@click.option("--association-only", is_flag=True)
Expand Down Expand Up @@ -86,7 +234,7 @@ def update_config(
"specified if --baseline-results is"
)
seed_config = config["phenotypes"][phenotype]
correction_method = seed_config.get("correction_method", None)
correction_method = config["baseline_results"].get("correction_method", None)
min_seed_genes = seed_config.get("min_seed_genes", 3)
max_seed_genes = seed_config.get("max_seed_genes", None)
threshold = seed_config.get("pvalue_threshold", None)
Expand Down
12 changes: 12 additions & 0 deletions pipelines/setup_deeprvat_config.snakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@


# Generates deeprvat_config.yaml by merging the user's input config with the
# packaged base configuration, via the `deeprvat_config create-main-config` CLI
# (see deeprvat/deeprvat/config.py). The CLI writes its output to the current
# working directory by default, which is where the rule's output is declared.
rule create_main_deeprvat_config:
    input:
        config_file = 'deeprvat_input_pretrained_models_config.yaml' #'deeprvat/example/config/deeprvat_input_config.yaml',
    output:
        'deeprvat_config.yaml'
    shell:
        (
            "deeprvat_config create-main-config "
            "{input.config_file} "
        )
File renamed without changes.
40 changes: 40 additions & 0 deletions pretrained_models/config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,43 @@
# Configuration shipped alongside the pretrained DeepRVAT models. These values
# override the corresponding user inputs when `use_pretrained_models` is set
# (see deeprvat/deeprvat/config.py, create_main_config).
# NOTE(review): presumably these must match what the shipped checkpoints were
# trained with -- confirm before editing.

# Annotation columns used as per-variant input features.
rare_variant_annotations:
  - MAF_MB
  - CADD_raw
  - sift_score
  - polyphen_score
  - Consequence_splice_acceptor_variant
  - Consequence_splice_donor_variant
  - Consequence_stop_gained
  - Consequence_frameshift_variant
  - Consequence_stop_lost
  - Consequence_start_lost
  - Consequence_inframe_insertion
  - Consequence_inframe_deletion
  - Consequence_missense_variant
  - Consequence_protein_altering_variant
  - Consequence_splice_region_variant
  - condel_score
  - DeepSEA_PC_1
  - DeepSEA_PC_2
  - DeepSEA_PC_3
  - DeepSEA_PC_4
  - DeepSEA_PC_5
  - DeepSEA_PC_6
  - PrimateAI_score
  - AbSplice_DNA
  - DeepRipe_plus_QKI_lip_hg2
  - DeepRipe_plus_QKI_clip_k5
  - DeepRipe_plus_KHDRBS1_clip_k5
  - DeepRipe_plus_ELAVL1_parclip
  - DeepRipe_plus_TARDBP_parclip
  - DeepRipe_plus_HNRNPD_parclip
  - DeepRipe_plus_MBNL1_parclip
  - DeepRipe_plus_QKI_parclip
  - SpliceAI_delta_score
  - alphamissense

# Variant-filter thresholds applied to the training data; each entry becomes a
# "<column> <operator> <value>" filter string in the generated config.
training_data_thresholds:
  MAF: "< 1e-2"
  CADD_PHRED: "> 5"

model:
checkpoint: combined_agg.pt
config:
Expand Down

0 comments on commit c9d027e

Please sign in to comment.