diff --git a/deeprvat/deeprvat/config.py b/deeprvat/deeprvat/config.py index 0df41734..4e026161 100644 --- a/deeprvat/deeprvat/config.py +++ b/deeprvat/deeprvat/config.py @@ -119,8 +119,11 @@ def create_main_config( ) ) - if no_pretrain: - if all(key not in input_config["training"] for key in ["pl_trainer", "early_stopping"]): + if no_pretrain: + if all( + key not in input_config["training"] + for key in ["pl_trainer", "early_stopping"] + ): raise KeyError( "Missing keys pl_trainer and/or early_stopping under config['training'] " "Please review DEEPRVAT_DIR/example/config/deeprvat_input_config.yaml for list of keys." @@ -203,7 +206,9 @@ def create_main_config( ]["thresholds"][k] = f"{k} {v}" # Results evaluation parameters; alpha parameter for significance threshold full_config["evaluation"] = {} - full_config["evaluation"]["correction_method"] = input_config["evaluation"]["correction_method"] + full_config["evaluation"]["correction_method"] = input_config["evaluation"][ + "correction_method" + ] full_config["evaluation"]["alpha"] = input_config["evaluation"]["alpha"] # DeepRVAT model full_config["n_repeats"] = input_config["n_repeats"] @@ -215,14 +220,18 @@ def create_main_config( # PL trainer full_config["training"]["pl_trainer"] = input_config["training"]["pl_trainer"] # Early Stopping - full_config["training"]["early_stopping"] = input_config["training"]["early_stopping"] + full_config["training"]["early_stopping"] = input_config["training"][ + "early_stopping" + ] # Training Phenotypes full_config["training"]["phenotypes"] = input_config["phenotypes_for_training"] # Baseline results full_config["baseline_results"]["options"] = input_config["seed_gene_results"][ "options" ] - full_config["baseline_results"]["alpha_seed_genes"] = input_config["seed_gene_results"]["alpha_seed_genes"] + full_config["baseline_results"]["alpha_seed_genes"] = input_config[ + "seed_gene_results" + ]["alpha_seed_genes"] else: full_config["model"] = input_config["model"] @@ -324,7 +333,9 @@ def update_config( else: logger.info("Not performing EAC filtering of baseline results") logger.info(f" Correcting p-values using {correction_method} method") - alpha = config["baseline_results"].get("alpha_seed_genes", config.get("alpha")) + alpha = config["baseline_results"].get( + "alpha_seed_genes", config.get("alpha") + ) baseline_df = pval_correction( baseline_df, alpha, correction_type=correction_method ) diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index ad2d6349..b3978aae 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -849,7 +849,9 @@ def run_bagging( # initialize trainer, which will call background functionality trainer = pl.Trainer( - logger=tb_logger, callbacks=callbacks, **config["training"].get("pl_trainer", {}) + logger=tb_logger, + callbacks=callbacks, + **config["training"].get("pl_trainer", {}), ) while True: