Skip to content

Commit

Permalink
Merge branch '132-config-maf-thresholds' of https://github.com/PMBio/…
Browse files Browse the repository at this point in the history
…deeprvat into 132-config-maf-thresholds
  • Loading branch information
meyerkm committed Sep 25, 2024
2 parents 1430706 + d190eb0 commit e8ca8ef
Showing 1 changed file with 161 additions and 70 deletions.
231 changes: 161 additions & 70 deletions deeprvat/deeprvat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,18 @@ def cli():
def setup_logging(log_filename: str = "config_generate.log"):
    """Attach a file handler to the module-level logger.

    Parameters
    ----------
    log_filename : str
        Path of the log file; opened in append mode so repeated runs
        accumulate into the same file.

    Returns
    -------
    logging.FileHandler
        The attached handler, so the caller can later remove and close it.
    """
    file_handler = logging.FileHandler(log_filename, mode="a")
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(
        logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s: %(message)s")
    )
    logger.addHandler(file_handler)
    return file_handler


def load_yaml(file_path: str):
    """Read *file_path* and return its parsed YAML content."""
    with open(file_path) as stream:
        parsed = yaml.safe_load(stream)
    return parsed


def update_defaults(base_config, input_config):
"""
Updates base_config with values from input_config, for intersecting nested keys.
Expand All @@ -59,43 +63,64 @@ def update_defaults(base_config, input_config):

return base_config


def handle_cv_options(input_config, full_config, expected_input_keys):
    """Copy cross-validation settings from the input config into the full config.

    If ``cv_options.cv_exp`` is truthy, the required CV keys are validated and
    hoisted to the top level of ``full_config``. Otherwise ``cv_exp`` is set
    False and the CV keys are stripped from the expected/input key sets.

    :param input_config: user-supplied configuration (mutated: ``cv_options``
        is popped in the non-CV case)
    :param full_config: output configuration being assembled (mutated)
    :param expected_input_keys: keys the input is expected to contain
        (mutated: ``cv_options`` removed in the non-CV case)
    :raises KeyError: if CV is enabled but ``cv_options`` is incomplete
    """
    if input_config.get("cv_options", {}).get("cv_exp", False):
        missing_keys = [
            key
            for key in ["cv_exp", "cv_path", "n_folds"]
            if key not in input_config["cv_options"]
        ]
        if missing_keys:
            raise KeyError(
                f"Missing keys {missing_keys} under config['cv_options'] \n\
                Please review DEEPRVAT_DIR/example/config/deeprvat_input_config.yaml for list of keys."
            )
        full_config.update(
            {
                "cv_path": input_config["cv_options"]["cv_path"],
                "n_folds": input_config["cv_options"]["n_folds"],
                "cv_exp": True,
            }
        )
    else:
        logger.info("Not CV setup...removing CV pipeline parameters from config")
        full_config["cv_exp"] = False
        expected_input_keys.remove("cv_options")
        input_config.pop("cv_options", None)


def handle_regenie_options(input_config, full_config, expected_input_keys):
    """Copy REGENIE integration settings into the full config.

    If ``regenie_options.regenie_exp`` is truthy, the required REGENIE keys
    are validated and copied into ``full_config``. Otherwise
    ``regenie_exp`` is set False and the REGENIE keys are stripped from the
    expected/input key sets.

    :param input_config: user-supplied configuration (mutated: popped in the
        non-REGENIE case)
    :param full_config: output configuration being assembled (mutated)
    :param expected_input_keys: keys the input is expected to contain
        (mutated: ``regenie_options`` removed in the non-REGENIE case)
    :raises KeyError: if REGENIE is enabled but ``regenie_options`` is incomplete
    """
    if input_config.get("regenie_options", {}).get("regenie_exp", False):
        # gtf_file is read unconditionally below, so validate it here too;
        # previously a missing gtf_file raised a bare, uninformative KeyError.
        missing_keys = [
            key
            for key in ["regenie_exp", "step_1", "step_2", "gtf_file"]
            if key not in input_config["regenie_options"]
        ]
        if missing_keys:
            raise KeyError(
                f"Missing keys {missing_keys} under config['regenie_options'] \n\
                Please review DEEPRVAT_DIR/example/config/deeprvat_input_config.yaml for list of keys."
            )
        full_config.update(
            {
                "regenie_exp": True,
                "regenie_options": {
                    "step_1": input_config["regenie_options"]["step_1"],
                    "step_2": input_config["regenie_options"]["step_2"],
                },
                "gtf_file": input_config["regenie_options"]["gtf_file"],
            }
        )
    else:
        logger.info(
            "Not using REGENIE integration...removing REGENIE parameters from config"
        )
        full_config["regenie_exp"] = False
        expected_input_keys.remove("regenie_options")
        input_config.pop("regenie_options", None)


def handle_pretrained_models(input_config, expected_input_keys):
if input_config.get("use_pretrained_models", False):
logger.info("Pretrained Model setup specified.")
Expand All @@ -104,86 +129,132 @@ def handle_pretrained_models(input_config, expected_input_keys):
expected_input_keys.remove(item)

pretrained_model_path = Path(input_config["pretrained_model_path"])
expected_input_keys.extend(["use_pretrained_models", "model", "pretrained_model_path"])
expected_input_keys.extend(
["use_pretrained_models", "model", "pretrained_model_path"]
)

pretrained_config = load_yaml(f"{pretrained_model_path}/model_config.yaml")

required_keys = {"model", "rare_variant_annotations", "training_data_thresholds"}
required_keys = {
"model",
"rare_variant_annotations",
"training_data_thresholds",
}
extra_keys = set(pretrained_config.keys()) - required_keys
if extra_keys:
raise KeyError(f"Unexpected key in pretrained_model_path/model_config.yaml file : {extra_keys} \n\
Please review DEEPRVAT_DIR/pretrained_models/model_config.yaml for expected list of keys.")
raise KeyError(
f"Unexpected key in pretrained_model_path/model_config.yaml file : {extra_keys} \n\
Please review DEEPRVAT_DIR/pretrained_models/model_config.yaml for expected list of keys."
)
logger.info(" Updating input config with keys from pretrained model config.")
input_config.update({
"model": pretrained_config["model"],
"rare_variant_annotations": pretrained_config["rare_variant_annotations"],
"training_data_thresholds": pretrained_config["training_data_thresholds"]
})
input_config.update(
{
"model": pretrained_config["model"],
"rare_variant_annotations": pretrained_config[
"rare_variant_annotations"
],
"training_data_thresholds": pretrained_config[
"training_data_thresholds"
],
}
)
return True
return False


def update_thresholds(input_config, full_config, train_only):
    """Translate per-dataset annotation thresholds into the full config.

    For "training_data" (and "association_testing_data" unless *train_only*),
    each entry of ``<dataset>_thresholds`` is written as an "<key> <op value>"
    filter string under the dataset's rare-embedding config, the threshold key
    is spliced into the annotation list, and the numeric MAF cutoff is
    recorded under ``min_common_af``.

    :param input_config: user-supplied configuration (read only)
    :param full_config: output configuration being assembled (mutated)
    :param train_only: when True, only the training dataset is processed
    :raises KeyError: if the mandatory MAF threshold is absent
    """
    if "MAF" not in input_config["training_data_thresholds"]:
        raise KeyError(
            "Missing required MAF threshold in config['training_data_thresholds']"
        )
    if (
        not train_only
        and "MAF" not in input_config["association_testing_data_thresholds"]
    ):
        raise KeyError(
            "Missing required MAF threshold in config['association_testing_data_thresholds']"
        )

    datasets = ["training_data", "association_testing_data"]
    if train_only:
        datasets.remove("association_testing_data")

    for data_type in datasets:
        anno_list = deepcopy(input_config["rare_variant_annotations"])
        dataset_config = full_config[data_type]["dataset_config"]
        thresholds = dataset_config["rare_embedding"]["config"]["thresholds"] = {}
        for i, (k, v) in enumerate(input_config[f"{data_type}_thresholds"].items()):
            thresholds[k] = f"{k} {v}"
            # Splice threshold columns in directly after the first annotation.
            anno_list.insert(i + 1, k)
            if k == "MAF":
                # v is a comparison string like "< 1e-3"; stripping the
                # two-character operator prefix recovers the numeric cutoff.
                dataset_config["min_common_af"]["MAF"] = float(v[2:])
        dataset_config["annotations"] = anno_list


def update_full_config(input_config, full_config, train_only):
    """Map flat input-config paths/settings onto the nested full config.

    File names and covariates from the user config are copied into the
    training (and, unless *train_only*, association-testing) sections of
    ``full_config``, along with the annotation list and gene file.

    :param input_config: user-supplied configuration (read only)
    :param full_config: output configuration being assembled (mutated)
    :param train_only: when True, association-testing sections are skipped
    """
    # Top-level files shared by both dataset sections.
    base_mapping = {
        "gt_filename": "gt_file",  # genotypes.h5
        "variant_filename": "variant_file",
    }
    # Settings that live under each dataset's dataset_config.
    dataset_mapping = {
        "phenotype_filename": "phenotype_file",  # phenotypes.parquet
        "annotation_filename": "annotation_file",  # annotations.parquet
        "covariates": "x_phenotypes",
    }

    for key, value in base_mapping.items():
        full_config["training_data"][value] = input_config[key]
        if not train_only:
            full_config["association_testing_data"][value] = input_config[key]

    for key, value in dataset_mapping.items():
        full_config["training_data"]["dataset_config"][value] = input_config[key]
        if not train_only:
            full_config["association_testing_data"]["dataset_config"][value] = (
                input_config[key]
            )

    full_config["training_data"]["dataset_config"]["rare_embedding"]["config"][
        "annotations"
    ] = input_config["rare_variant_annotations"]
    # NOTE(review): gene_file is written even when train_only — this assumes
    # the base template always contains association_testing_data; confirm.
    full_config["association_testing_data"]["dataset_config"]["gene_file"] = (
        input_config["gene_filename"]
    )  # protein_coding_genes.parquet
    if not train_only:
        full_config["phenotypes"] = input_config["phenotypes_for_association_testing"]
        full_config["association_testing_data"]["dataset_config"]["rare_embedding"][
            "config"
        ]["gene_file"] = input_config["gene_filename"]
        full_config["association_testing_data"]["dataset_config"]["rare_embedding"][
            "config"
        ]["annotations"] = input_config["rare_variant_annotations"]


def validate_keys(input_config, expected_input_keys, optional_input_keys, base_config):
    """Check the user config for unknown and missing top-level keys.

    Keys listed in *optional_input_keys*, and keys that merely override
    defaults present in *base_config*, are tolerated.

    :param input_config: user-supplied configuration (read only)
    :param expected_input_keys: keys that must be present
    :param optional_input_keys: keys that may be present
    :param base_config: base/default configuration whose keys may be overridden
    :raises KeyError: on extra or missing keys
    """
    input_keys_set = set(input_config.keys()) - set(optional_input_keys)
    expected_keys_set = set(expected_input_keys)
    # Keys shared with the base config are legitimate default overrides.
    updated_base_keys = set(base_config.keys()).intersection(input_config.keys())

    extra_keys = input_keys_set - expected_keys_set - updated_base_keys
    missing_keys = expected_keys_set - input_keys_set

    if extra_keys:
        raise KeyError(
            f"Extra key(s) present in input YAML file: {extra_keys} \n\
            Please review DEEPRVAT_DIR/example/config/deeprvat_input_config.yaml for list of keys."
        )
    if missing_keys:
        raise KeyError(
            f"Missing key(s) in input YAML file: {missing_keys} \n\
            Please review DEEPRVAT_DIR/example/config/deeprvat_input_config.yaml for list of keys."
        )


def create_main_config(
config_file: str,
Expand Down Expand Up @@ -257,7 +328,11 @@ def create_main_config(

train_only = input_config.pop("training_only", False)
if train_only:
to_remove = ["phenotypes_for_association_testing", "association_testing_data_thresholds", "evaluation"]
to_remove = [
"phenotypes_for_association_testing",
"association_testing_data_thresholds",
"evaluation",
]
for item in to_remove:
expected_input_keys.remove(item)

Expand All @@ -269,14 +344,20 @@ def create_main_config(
if train_only:
raise KeyError("Must specify phenotypes_for_training in config file!")
logger.info(
" Setting training phenotypes to be the same set as specified by phenotypes_for_association_testing."
)
input_config["phenotypes_for_training"] = input_config["phenotypes_for_association_testing"]
" Setting training phenotypes to be the same set as specified by phenotypes_for_association_testing."
)
input_config["phenotypes_for_training"] = input_config[
"phenotypes_for_association_testing"
]

if "y_transformation" in input_config:
full_config["training_data"]["dataset_config"]["y_transformation"] = input_config["y_transformation"]
full_config["training_data"]["dataset_config"]["y_transformation"] = (
input_config["y_transformation"]
)
if not train_only:
full_config["association_testing_data"]["dataset_config"]["y_transformation"] = input_config["y_transformation"]
full_config["association_testing_data"]["dataset_config"][
"y_transformation"
] = input_config["y_transformation"]
else:
expected_input_keys.remove("y_transformation")

Expand All @@ -290,34 +371,44 @@ def create_main_config(
if "sample_files" in input_config:
for key in ["training", "association_testing"]:
if key in input_config["sample_files"]:
full_config[f"{key}_data"]["dataset_config"]["sample_file"] = input_config["sample_files"][key]
full_config[f"{key}_data"]["dataset_config"]["sample_file"] = (
input_config["sample_files"][key]
)

# Results evaluation parameters; alpha parameter for significance threshold
if not train_only:
full_config["evaluation"] = {
"correction_method": input_config["evaluation"]["correction_method"],
"alpha": input_config["evaluation"]["alpha"]
"alpha": input_config["evaluation"]["alpha"],
}

if pretrained_setup:
full_config.update({
"model": input_config["model"],
"pretrained_model_path": input_config["pretrained_model_path"]
})
full_config.update(
{
"model": input_config["model"],
"pretrained_model_path": input_config["pretrained_model_path"],
}
)
else:
full_config["training"]["pl_trainer"] = input_config["training"]["pl_trainer"]
full_config["training"]["early_stopping"] = input_config["training"]["early_stopping"]
full_config["training"]["phenotypes"] = {pheno: {} for pheno in input_config["phenotypes_for_training"]}
full_config["training"]["early_stopping"] = input_config["training"][
"early_stopping"
]
full_config["training"]["phenotypes"] = {
pheno: {} for pheno in input_config["phenotypes_for_training"]
}
# For each phenotype, you can optionally specify dictionary of = {"min_seed_genes": 3, "max_seed_genes": None, "pvalue_threshold": None}
full_config["baseline_results"] = {
"options": input_config["seed_gene_results"]["result_dirs"],
"alpha_seed_genes": input_config["seed_gene_results"]["alpha_seed_genes"],
"correction_method": input_config["seed_gene_results"]["correction_method"]
"correction_method": input_config["seed_gene_results"]["correction_method"],
}

with open(output_path, "w") as f:
yaml.dump(full_config, f)
logger.info(f"Saving deeprvat_config.yaml to -- {output_dir}/deeprvat_config.yaml --")
logger.info(
f"Saving deeprvat_config.yaml to -- {output_dir}/deeprvat_config.yaml --"
)

logger.removeHandler(file_handler)
file_handler.close()
Expand Down

0 comments on commit e8ca8ef

Please sign in to comment.