From fc14c51a362418f11212f84ba49593b3b363c75c Mon Sep 17 00:00:00 2001 From: Kayla Meyer <129152803+meyerkm@users.noreply.github.com> Date: Fri, 22 Dec 2023 11:44:46 +0100 Subject: [PATCH] Feature snakemake modular (#41) * making snakemake runners modular Additional snakemake runners for only running training, only association testing, and full train + association testing pipelines. * bug-fix pretrained model path * Adding additional snakemake pipeline run option to readthedocs * update train snakefile pipeline from PR #42 * bug-fix model path for snakemake pipeline runners * bug-fix f string syntax * Update github-actions.yml * Update github-actions.yml * Update github-actions.yml * Revert "Update github-actions.yml" This reverts commit b8b82e568f823fa8030e244c7b2515b7ad9d3813. * Revert "Update github-actions.yml" This reverts commit d27e6a8b9d65561717536cd2111f8b91891cc558. * Revert "Update github-actions.yml" This reverts commit 6cde84f1188a7f3aae910c72b1c53fbd9c80f219. * Update github-actions.yml * Update github-actions.yml * fix-model path string variable in rules --------- Co-authored-by: Magnus Wahlberg --- docs/usage.md | 14 + .../association_dataset.snakefile | 12 + .../association_testing/burdens.snakefile | 74 +++++ .../regress_eval.snakefile | 63 ++++ .../association_testing_pretrained.snakefile | 188 +----------- pipelines/run_training.snakefile | 51 +++ pipelines/training/config.snakefile | 29 ++ pipelines/training/train.snakefile | 64 ++++ pipelines/training/training_dataset.snakefile | 37 +++ .../training_association_testing.snakefile | 290 +----------------- 10 files changed, 363 insertions(+), 459 deletions(-) create mode 100644 pipelines/association_testing/association_dataset.snakefile create mode 100644 pipelines/association_testing/burdens.snakefile create mode 100644 pipelines/association_testing/regress_eval.snakefile create mode 100644 pipelines/run_training.snakefile create mode 100644 pipelines/training/config.snakefile create mode 100644 pipelines/training/train.snakefile create mode 100644 pipelines/training/training_dataset.snakefile diff --git a/docs/usage.md b/docs/usage.md index 93361782..5d7c9170 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -56,6 +56,20 @@ Replace `[path_to_deeprvat]` with the path to your clone of the repository. Note that the example data is randomly generated, and so is only suited for testing whether the `deeprvat` package has been correctly installed. +### Run the training pipeline on some example data + +```shell +mkdir example +cd example +ln -s [path_to_deeprvat]/example/* . +snakemake -j 1 --snakefile [path_to_deeprvat]/pipelines/run_training.snakefile +``` + +Replace `[path_to_deeprvat]` with the path to your clone of the repository. + +Note that the example data is randomly generated, and so is only suited for testing whether the `deeprvat` package has been correctly installed. + + ### Run the association testing pipeline with pretrained models ```shell diff --git a/pipelines/association_testing/association_dataset.snakefile b/pipelines/association_testing/association_dataset.snakefile new file mode 100644 index 00000000..0e63e53f --- /dev/null +++ b/pipelines/association_testing/association_dataset.snakefile @@ -0,0 +1,12 @@ + +rule association_dataset: + input: + config = '{phenotype}/deeprvat/hpopt_config.yaml' + output: + '{phenotype}/deeprvat/association_dataset.pkl' + threads: 4 + shell: + 'deeprvat_associate make-dataset ' + + debug + + '{input.config} ' + '{output}' diff --git a/pipelines/association_testing/burdens.snakefile b/pipelines/association_testing/burdens.snakefile new file mode 100644 index 00000000..550390fa --- /dev/null +++ b/pipelines/association_testing/burdens.snakefile @@ -0,0 +1,74 @@ + +rule link_burdens: + priority: 1 + input: + checkpoints = lambda wildcards: [ + f'{model_path}/repeat_{repeat}/best/bag_{bag}.ckpt' + for repeat in range(n_repeats) for bag in range(n_bags) + ], + dataset = '{phenotype}/deeprvat/association_dataset.pkl', + data_config = '{phenotype}/deeprvat/hpopt_config.yaml', + model_config = model_path / 'config.yaml', + output: + '{phenotype}/deeprvat/burdens/chunk{chunk}.linked' + threads: 8 + shell: + ' && '.join([ + ('deeprvat_associate compute-burdens ' + + debug + + ' --n-chunks '+ str(n_burden_chunks) + ' ' + f'--link-burdens ../../../{phenotypes[0]}/deeprvat/burdens/burdens.zarr ' + '--chunk {wildcards.chunk} ' + '--dataset-file {input.dataset} ' + '{input.data_config} ' + '{input.model_config} ' + '{input.checkpoints} ' + '{wildcards.phenotype}/deeprvat/burdens'), + 'touch {output}' + ]) + +rule compute_burdens: + priority: 10 + input: + reversed = model_path / "reverse_finished.tmp", + checkpoints = lambda wildcards: [ + model_path / f'repeat_{repeat}/best/bag_{bag}.ckpt' + for repeat in range(n_repeats) for bag in range(n_bags) + ], + dataset = '{phenotype}/deeprvat/association_dataset.pkl', + data_config = '{phenotype}/deeprvat/hpopt_config.yaml', + model_config = model_path / 'config.yaml', + output: + '{phenotype}/deeprvat/burdens/chunk{chunk}.finished' + threads: 8 + shell: + ' && '.join([ + ('deeprvat_associate compute-burdens ' + + debug + + ' --n-chunks '+ str(n_burden_chunks) + ' ' + '--chunk {wildcards.chunk} ' + '--dataset-file {input.dataset} ' + '{input.data_config} ' + '{input.model_config} ' + '{input.checkpoints} ' + '{wildcards.phenotype}/deeprvat/burdens'), + 'touch {output}' + ]) + +rule reverse_models: + input: + checkpoints = expand(model_path / 'repeat_{repeat}/best/bag_{bag}.ckpt', + bag=range(n_bags), repeat=range(n_repeats)), + model_config = model_path / 'config.yaml', + data_config = Path(phenotypes[0]) / "deeprvat/hpopt_config.yaml", + output: + temp(model_path / "reverse_finished.tmp") + threads: 4 + shell: + " && ".join([ + ("deeprvat_associate reverse-models " + "{input.model_config} " + "{input.data_config} " + "{input.checkpoints}"), + "touch {output}" + ]) \ No newline at end of file diff --git a/pipelines/association_testing/regress_eval.snakefile b/pipelines/association_testing/regress_eval.snakefile new file mode 100644 index 00000000..bcb3f369 --- /dev/null +++ b/pipelines/association_testing/regress_eval.snakefile @@ -0,0 +1,63 @@ + +rule evaluate: + input: + associations = expand('{{phenotype}}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', + repeat=range(n_repeats)), + config = '{phenotype}/deeprvat/hpopt_config.yaml', + output: + "{phenotype}/deeprvat/eval/significant.parquet", + "{phenotype}/deeprvat/eval/all_results.parquet" + threads: 1 + shell: + 'deeprvat_evaluate ' + + debug + + '--use-seed-genes ' + '--n-repeats {n_repeats} ' + '--correction-method FDR ' + '{input.associations} ' + '{input.config} ' + '{wildcards.phenotype}/deeprvat/eval' + +rule all_regression: + input: + expand('{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', + phenotype=phenotypes, type=['deeprvat'], repeat=range(n_repeats)), + +rule combine_regression_chunks: + input: + expand('{{phenotype}}/deeprvat/repeat_{{repeat}}/results/burden_associations_{chunk}.parquet', chunk=range(n_regression_chunks)), + output: + '{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', + threads: 1 + shell: + 'deeprvat_associate combine-regression-results ' + '--model-name repeat_{wildcards.repeat} ' + '{input} ' + '{output}' + +rule regress: + input: + config = "{phenotype}/deeprvat/hpopt_config.yaml", + chunks = lambda wildcards: expand( + ('{{phenotype}}/deeprvat/burdens/chunk{chunk}.' + + ("finished" if wildcards.phenotype == phenotypes[0] else "linked")), + chunk=range(n_burden_chunks) + ), + phenotype_0_chunks = expand( + phenotypes[0] + '/deeprvat/burdens/chunk{chunk}.finished', + chunk=range(n_burden_chunks) + ), + output: + temp('{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations_{chunk}.parquet'), + threads: 2 + shell: + 'deeprvat_associate regress ' + + debug + + '--chunk {wildcards.chunk} ' + '--n-chunks ' + str(n_regression_chunks) + ' ' + '--use-bias ' + '--repeat {wildcards.repeat} ' + + do_scoretest + + '{input.config} ' + '{wildcards.phenotype}/deeprvat/burdens ' #TODO make this w/o repeats + '{wildcards.phenotype}/deeprvat/repeat_{wildcards.repeat}/results' \ No newline at end of file diff --git a/pipelines/association_testing_pretrained.snakefile b/pipelines/association_testing_pretrained.snakefile index 702302f0..d7aaa006 100644 --- a/pipelines/association_testing_pretrained.snakefile +++ b/pipelines/association_testing_pretrained.snakefile @@ -5,19 +5,27 @@ configfile: 'config.yaml' debug_flag = config.get('debug', False) phenotypes = config['phenotypes'] phenotypes = list(phenotypes.keys()) if type(phenotypes) == dict else phenotypes +training_phenotypes = config["training"].get("phenotypes", phenotypes) n_burden_chunks = config.get('n_burden_chunks', 1) if not debug_flag else 2 n_regression_chunks = config.get('n_regression_chunks', 40) if not debug_flag else 2 +n_trials = config['hyperparameter_optimization']['n_trials'] n_bags = config['training']['n_bags'] if not debug_flag else 3 n_repeats = config['n_repeats'] debug = '--debug ' if debug_flag else '' do_scoretest = '--do-scoretest ' if config.get('do_scoretest', False) else '' -pretrained_model_path = Path(config.get("pretrained_model_path", "pretrained_models")) +tensor_compression_level = config['training'].get('tensor_compression_level', 1) +model_path = Path(config.get("pretrained_model_path", "pretrained_models")) wildcard_constraints: repeat="\d+", trial="\d+", +include: "training/config.snakefile" +include: "association_testing/association_dataset.snakefile" +include: "association_testing/burdens.snakefile" +include: "association_testing/regress_eval.snakefile" + rule all: input: expand("{phenotype}/deeprvat/eval/significant.parquet", @@ -25,69 +33,6 @@ rule all: expand("{phenotype}/deeprvat/eval/all_results.parquet", phenotype=phenotypes) -rule evaluate: - input: - associations = expand('{{phenotype}}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', - repeat=range(n_repeats)), - config = '{phenotype}/deeprvat/hpopt_config.yaml', - output: - "{phenotype}/deeprvat/eval/significant.parquet", - "{phenotype}/deeprvat/eval/all_results.parquet" - threads: 1 - shell: - 'deeprvat_evaluate ' - + debug + - '--use-seed-genes ' - '--n-repeats {n_repeats} ' - '--correction-method FDR ' - '{input.associations} ' - '{input.config} ' - '{wildcards.phenotype}/deeprvat/eval' - -rule all_regression: - input: - expand('{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', - phenotype=phenotypes, type=['deeprvat'], repeat=range(n_repeats)), - -rule combine_regression_chunks: - input: - expand('{{phenotype}}/deeprvat/repeat_{{repeat}}/results/burden_associations_{chunk}.parquet', chunk=range(n_regression_chunks)), - output: - '{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', - threads: 1 - shell: - 'deeprvat_associate combine-regression-results ' - '--model-name repeat_{wildcards.repeat} ' - '{input} ' - '{output}' - -rule regress: - input: - config = "{phenotype}/deeprvat/hpopt_config.yaml", - chunks = lambda wildcards: expand( - ('{{phenotype}}/deeprvat/burdens/chunk{chunk}.' + - ("finished" if wildcards.phenotype == phenotypes[0] else "linked")), - chunk=range(n_burden_chunks) - ), - phenotype_0_chunks = expand( - phenotypes[0] + '/deeprvat/burdens/chunk{chunk}.finished', - chunk=range(n_burden_chunks) - ), - output: - temp('{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations_{chunk}.parquet'), - threads: 2 - shell: - 'deeprvat_associate regress ' - + debug + - '--chunk {wildcards.chunk} ' - '--n-chunks ' + str(n_regression_chunks) + ' ' - '--use-bias ' - '--repeat {wildcards.repeat} ' - + do_scoretest + - '{input.config} ' - '{wildcards.phenotype}/deeprvat/burdens ' #TODO make this w/o repeats - '{wildcards.phenotype}/deeprvat/repeat_{wildcards.repeat}/results' - rule all_burdens: input: [ @@ -97,97 +42,11 @@ rule all_burdens: for c in range(n_burden_chunks) ] -rule link_burdens: - priority: 1 - input: - checkpoints = lambda wildcards: [ - f'{pretrained_model_path}/repeat_{repeat}/best/bag_{bag}.ckpt' - for repeat in range(n_repeats) for bag in range(n_bags) - ], - dataset = '{phenotype}/deeprvat/association_dataset.pkl', - data_config = '{phenotype}/deeprvat/hpopt_config.yaml', - model_config = pretrained_model_path / 'config.yaml', - output: - '{phenotype}/deeprvat/burdens/chunk{chunk}.linked' - threads: 8 - shell: - ' && '.join([ - ('deeprvat_associate compute-burdens ' - + debug + - ' --n-chunks '+ str(n_burden_chunks) + ' ' - f'--link-burdens ../../../{phenotypes[0]}/deeprvat/burdens/burdens.zarr ' - '--chunk {wildcards.chunk} ' - '--dataset-file {input.dataset} ' - '{input.data_config} ' - '{input.model_config} ' - '{input.checkpoints} ' - '{wildcards.phenotype}/deeprvat/burdens'), - 'touch {output}' - ]) - -rule compute_burdens: - priority: 10 - input: - reversed = pretrained_model_path / "reverse_finished.tmp", - checkpoints = lambda wildcards: [ - pretrained_model_path / f'repeat_{repeat}/best/bag_{bag}.ckpt' - for repeat in range(n_repeats) for bag in range(n_bags) - ], - dataset = '{phenotype}/deeprvat/association_dataset.pkl', - data_config = '{phenotype}/deeprvat/hpopt_config.yaml', - model_config = pretrained_model_path / 'config.yaml', - output: - '{phenotype}/deeprvat/burdens/chunk{chunk}.finished' - threads: 8 - shell: - ' && '.join([ - ('deeprvat_associate compute-burdens ' - + debug + - ' --n-chunks '+ str(n_burden_chunks) + ' ' - '--chunk {wildcards.chunk} ' - '--dataset-file {input.dataset} ' - '{input.data_config} ' - '{input.model_config} ' - '{input.checkpoints} ' - '{wildcards.phenotype}/deeprvat/burdens'), - 'touch {output}' - ]) - rule all_association_dataset: input: expand('{phenotype}/deeprvat/association_dataset.pkl', phenotype=phenotypes) -rule association_dataset: - input: - config = '{phenotype}/deeprvat/hpopt_config.yaml' - output: - '{phenotype}/deeprvat/association_dataset.pkl' - threads: 4 - shell: - 'deeprvat_associate make-dataset ' - + debug + - '{input.config} ' - '{output}' - -rule reverse_models: - input: - checkpoints = expand(pretrained_model_path / 'repeat_{repeat}/best/bag_{bag}.ckpt', - bag=range(n_bags), repeat=range(n_repeats)), - model_config = pretrained_model_path / 'config.yaml', - data_config = Path(phenotypes[0]) / "deeprvat/hpopt_config.yaml", - output: - temp(pretrained_model_path / "reverse_finished.tmp") - threads: 4 - shell: - " && ".join([ - ("deeprvat_associate reverse-models " - "{input.model_config} " - "{input.data_config} " - "{input.checkpoints}"), - "touch {output}" - ]) - rule all_config: input: seed_genes = expand('{phenotype}/deeprvat/seed_genes.parquet', @@ -196,32 +55,3 @@ rule all_config: phenotype=phenotypes), baseline = expand('{phenotype}/deeprvat/baseline_results.parquet', phenotype=phenotypes), - -rule config: - input: - config = 'config.yaml', - baseline = lambda wildcards: [ - str(Path(r['base']) / wildcards.phenotype / r['type'] / - 'eval/burden_associations.parquet') - for r in config['baseline_results'] - ] - output: - seed_genes = '{phenotype}/deeprvat/seed_genes.parquet', - config = '{phenotype}/deeprvat/hpopt_config.yaml', - baseline = '{phenotype}/deeprvat/baseline_results.parquet', - threads: 1 - params: - baseline_results = lambda wildcards, input: ''.join([ - f'--baseline-results {b} ' - for b in input.baseline - ]) - shell: - ( - 'deeprvat_config update-config ' - '--phenotype {wildcards.phenotype} ' - '{params.baseline_results}' - '--baseline-results-out {output.baseline} ' - '--seed-genes-out {output.seed_genes} ' - '{input.config} ' - '{output.config}' - ) diff --git a/pipelines/run_training.snakefile b/pipelines/run_training.snakefile new file mode 100644 index 00000000..0e10d79e --- /dev/null +++ b/pipelines/run_training.snakefile @@ -0,0 +1,51 @@ +from pathlib import Path + +configfile: 'config.yaml' + +debug_flag = config.get('debug', False) +phenotypes = config['phenotypes'] +phenotypes = list(phenotypes.keys()) if type(phenotypes) == dict else phenotypes +training_phenotypes = config["training"].get("phenotypes", phenotypes) + +n_burden_chunks = config.get('n_burden_chunks', 1) if not debug_flag else 2 +n_regression_chunks = config.get('n_regression_chunks', 40) if not debug_flag else 2 +n_trials = config['hyperparameter_optimization']['n_trials'] +n_bags = config['training']['n_bags'] if not debug_flag else 3 +n_repeats = config['n_repeats'] +debug = '--debug ' if debug_flag else '' +do_scoretest = '--do-scoretest ' if config.get('do_scoretest', False) else '' +tensor_compression_level = config['training'].get('tensor_compression_level', 1) +model_path = Path("models") +n_parallel_training_jobs = config["training"].get("n_parallel_jobs", 1) + +wildcard_constraints: + repeat="\d+", + trial="\d+", + +include: "training/config.snakefile" +include: "training/training_dataset.snakefile" +include: "training/train.snakefile" + +rule all: + input: + expand( model_path / 'repeat_{repeat}/best/bag_{bag}.ckpt', + bag=range(n_bags), repeat=range(n_repeats)), + model_path / "config.yaml" + +rule all_training_dataset: + input: + input_tensor = expand('{phenotype}/deeprvat/input_tensor.zarr', + phenotype=training_phenotypes, repeat=range(n_repeats)), + covariates = expand('{phenotype}/deeprvat/covariates.zarr', + phenotype=training_phenotypes, repeat=range(n_repeats)), + y = expand('{phenotype}/deeprvat/y.zarr', + phenotype=training_phenotypes, repeat=range(n_repeats)) + +rule all_config: + input: + seed_genes = expand('{phenotype}/deeprvat/seed_genes.parquet', + phenotype=phenotypes), + config = expand('{phenotype}/deeprvat/hpopt_config.yaml', + phenotype=phenotypes), + baseline = expand('{phenotype}/deeprvat/baseline_results.parquet', + phenotype=phenotypes), \ No newline at end of file diff --git a/pipelines/training/config.snakefile b/pipelines/training/config.snakefile new file mode 100644 index 00000000..3c58a39d --- /dev/null +++ b/pipelines/training/config.snakefile @@ -0,0 +1,29 @@ + +rule config: + input: + config = 'config.yaml', + baseline = lambda wildcards: [ + str(Path(r['base']) / wildcards.phenotype / r['type'] / + 'eval/burden_associations.parquet') + for r in config['baseline_results'] + ] + output: + seed_genes = '{phenotype}/deeprvat/seed_genes.parquet', + config = '{phenotype}/deeprvat/hpopt_config.yaml', + baseline = '{phenotype}/deeprvat/baseline_results.parquet', + threads: 1 + params: + baseline_results = lambda wildcards, input: ''.join([ + f'--baseline-results {b} ' + for b in input.baseline + ]) + shell: + ( + 'deeprvat_config update-config ' + '--phenotype {wildcards.phenotype} ' + '{params.baseline_results}' + '--baseline-results-out {output.baseline} ' + '--seed-genes-out {output.seed_genes} ' + '{input.config} ' + '{output.config}' + ) \ No newline at end of file diff --git a/pipelines/training/train.snakefile b/pipelines/training/train.snakefile new file mode 100644 index 00000000..c747fd1f --- /dev/null +++ b/pipelines/training/train.snakefile @@ -0,0 +1,64 @@ + +rule link_config: + input: + model_path / 'repeat_0/config.yaml' + output: + model_path / 'config.yaml' + threads: 1 + shell: + "ln -s repeat_0/config.yaml {output}" + + +rule best_training_run: + input: + expand(model_path / 'repeat_{{repeat}}/trial{trial_number}/config.yaml', + trial_number=range(n_trials)), + output: + checkpoints = expand(model_path / 'repeat_{{repeat}}/best/bag_{bag}.ckpt', + bag=range(n_bags)), + config = model_path / 'repeat_{repeat}/config.yaml' + threads: 1 + shell: + ( + 'deeprvat_train best-training-run ' + + debug + + '{model_path}/repeat_{wildcards.repeat} ' + '{model_path}/repeat_{wildcards.repeat}/best ' + '{model_path}/repeat_{wildcards.repeat}/hyperparameter_optimization.db ' + '{output.config}' + ) + +rule train: + input: + config = expand('{phenotype}/deeprvat/hpopt_config.yaml', + phenotype=training_phenotypes), + input_tensor = expand('{phenotype}/deeprvat/input_tensor.zarr', + phenotype=training_phenotypes), + covariates = expand('{phenotype}/deeprvat/covariates.zarr', + phenotype=training_phenotypes), + y = expand('{phenotype}/deeprvat/y.zarr', + phenotype=training_phenotypes), + output: + expand(model_path / 'repeat_{repeat}/trial{trial_number}/config.yaml', + repeat=range(n_repeats), trial_number=range(n_trials)), + expand(model_path / 'repeat_{repeat}/trial{trial_number}/finished.tmp', + repeat=range(n_repeats), trial_number=range(n_trials)) + params: + phenotypes = " ".join( + [f"--phenotype {p} " + f"{p}/deeprvat/input_tensor.zarr " + f"{p}/deeprvat/covariates.zarr " + f"{p}/deeprvat/y.zarr" + for p in training_phenotypes]) + shell: + f"parallel --jobs {n_parallel_training_jobs} --halt now,fail=1 --results train_repeat{{{{1}}}}_trial{{{{2}}}}/ " + 'deeprvat_train train ' + + debug + + '--trial-id {{2}} ' + "{params.phenotypes} " + 'config.yaml ' + '{model_path}/repeat_{{1}}/trial{{2}} ' + '{model_path}/repeat_{{1}}/hyperparameter_optimization.db "&&" ' + 'touch {model_path}/repeat_{{1}}/trial{{2}}/finished.tmp ' + "::: " + " ".join(map(str, range(n_repeats))) + " " + "::: " + " ".join(map(str, range(n_trials))) diff --git a/pipelines/training/training_dataset.snakefile b/pipelines/training/training_dataset.snakefile new file mode 100644 index 00000000..66903b85 --- /dev/null +++ b/pipelines/training/training_dataset.snakefile @@ -0,0 +1,37 @@ + +rule training_dataset: + input: + config = '{phenotype}/deeprvat/hpopt_config.yaml', + training_dataset = '{phenotype}/deeprvat/training_dataset.pkl' + output: + input_tensor = directory('{phenotype}/deeprvat/input_tensor.zarr'), + covariates = directory('{phenotype}/deeprvat/covariates.zarr'), + y = directory('{phenotype}/deeprvat/y.zarr') + threads: 8 + priority: 50 + shell: + ( + 'deeprvat_train make-dataset ' + + debug + + '--compression-level ' + str(tensor_compression_level) + ' ' + '--training-dataset-file {input.training_dataset} ' + '{input.config} ' + '{output.input_tensor} ' + '{output.covariates} ' + '{output.y}' + ) + +rule training_dataset_pickle: + input: + '{phenotype}/deeprvat/hpopt_config.yaml' + output: + '{phenotype}/deeprvat/training_dataset.pkl' + threads: 1 + shell: + ( + 'deeprvat_train make-dataset ' + '--pickle-only ' + '--training-dataset-file {output} ' + '{input} ' + 'dummy dummy dummy' + ) \ No newline at end of file diff --git a/pipelines/training_association_testing.snakefile b/pipelines/training_association_testing.snakefile index 1e887e1e..60384eaf 100644 --- a/pipelines/training_association_testing.snakefile +++ b/pipelines/training_association_testing.snakefile @@ -15,12 +15,20 @@ n_repeats = config['n_repeats'] debug = '--debug ' if debug_flag else '' do_scoretest = '--do-scoretest ' if config.get('do_scoretest', False) else '' tensor_compression_level = config['training'].get('tensor_compression_level', 1) +model_path = Path("models") n_parallel_training_jobs = config["training"].get("n_parallel_jobs", 1) wildcard_constraints: repeat="\d+", trial="\d+", +include: "training/config.snakefile" +include: "training/training_dataset.snakefile" +include: "training/train.snakefile" +include: "association_testing/association_dataset.snakefile" +include: "association_testing/burdens.snakefile" +include: "association_testing/regress_eval.snakefile" + rule all: input: expand("{phenotype}/deeprvat/eval/significant.parquet", @@ -28,69 +36,6 @@ rule all: expand("{phenotype}/deeprvat/eval/all_results.parquet", phenotype=phenotypes) -rule evaluate: - input: - associations = expand('{{phenotype}}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', - repeat=range(n_repeats)), - config = '{phenotype}/deeprvat/hpopt_config.yaml', - output: - "{phenotype}/deeprvat/eval/significant.parquet", - "{phenotype}/deeprvat/eval/all_results.parquet" - threads: 1 - shell: - 'deeprvat_evaluate ' - + debug + - '--use-seed-genes ' - '--n-repeats {n_repeats} ' - '--correction-method FDR ' - '{input.associations} ' - '{input.config} ' - '{wildcards.phenotype}/deeprvat/eval' - -rule all_regression: - input: - expand('{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', - phenotype=phenotypes, type=['deeprvat'], repeat=range(n_repeats)), - -rule combine_regression_chunks: - input: - expand('{{phenotype}}/deeprvat/repeat_{{repeat}}/results/burden_associations_{chunk}.parquet', chunk=range(n_regression_chunks)), - output: - '{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', - threads: 1 - shell: - 'deeprvat_associate combine-regression-results ' - '--model-name repeat_{wildcards.repeat} ' - '{input} ' - '{output}' - -rule regress: - input: - config = "{phenotype}/deeprvat/hpopt_config.yaml", - chunks = lambda wildcards: ( - [] if wildcards.phenotype == phenotypes[0] - else expand('{{phenotype}}/deeprvat/burdens/chunk{chunk}.linked', - chunk=range(n_burden_chunks)) - ), - phenotype_0_chunks = expand( - phenotypes[0] + '/deeprvat/burdens/chunk{chunk}.finished', - chunk=range(n_burden_chunks) - ), - output: - temp('{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations_{chunk}.parquet'), - threads: 2 - shell: - 'deeprvat_associate regress ' - + debug + - '--chunk {wildcards.chunk} ' - '--n-chunks ' + str(n_regression_chunks) + ' ' - '--use-bias ' - '--repeat {wildcards.repeat} ' - + do_scoretest + - '{input.config} ' - '{wildcards.phenotype}/deeprvat/burdens ' #TODO make this w/o repeats - '{wildcards.phenotype}/deeprvat/repeat_{wildcards.repeat}/results' - rule all_burdens: input: [ @@ -100,165 +45,16 @@ rule all_burdens: for c in range(n_burden_chunks) ] -rule link_burdens: - priority: 1 - input: - checkpoints = lambda wildcards: [ - f'models/repeat_{repeat}/best/bag_{bag}.ckpt' - for repeat in range(n_repeats) for bag in range(n_bags) - ], - dataset = '{phenotype}/deeprvat/association_dataset.pkl', - data_config = '{phenotype}/deeprvat/hpopt_config.yaml', - model_config = 'models/config.yaml', - output: - '{phenotype}/deeprvat/burdens/chunk{chunk}.linked' - threads: 8 - shell: - ' && '.join([ - ('deeprvat_associate compute-burdens ' - + debug + - ' --n-chunks '+ str(n_burden_chunks) + ' ' - f'--link-burdens ../../../{phenotypes[0]}/deeprvat/burdens/burdens.zarr ' - '--chunk {wildcards.chunk} ' - '--dataset-file {input.dataset} ' - '{input.data_config} ' - '{input.model_config} ' - '{input.checkpoints} ' - '{wildcards.phenotype}/deeprvat/burdens'), - 'touch {output}' - ]) - -rule compute_burdens: - priority: 10 - input: - reversed = "models/reverse_finished.tmp", - checkpoints = lambda wildcards: [ - f'models/repeat_{repeat}/best/bag_{bag}.ckpt' - for repeat in range(n_repeats) for bag in range(n_bags) - ], - dataset = '{phenotype}/deeprvat/association_dataset.pkl', - data_config = '{phenotype}/deeprvat/hpopt_config.yaml', - model_config = 'models/config.yaml', - output: - '{phenotype}/deeprvat/burdens/chunk{chunk}.finished' - threads: 8 - shell: - ' && '.join([ - ('deeprvat_associate compute-burdens ' - + debug + - ' --n-chunks '+ str(n_burden_chunks) + ' ' - '--chunk {wildcards.chunk} ' - '--dataset-file {input.dataset} ' - '{input.data_config} ' - '{input.model_config} ' - '{input.checkpoints} ' - '{wildcards.phenotype}/deeprvat/burdens'), - 'touch {output}' - ]) - rule all_association_dataset: input: expand('{phenotype}/deeprvat/association_dataset.pkl', phenotype=phenotypes) -rule association_dataset: - input: - config = '{phenotype}/deeprvat/hpopt_config.yaml' - output: - '{phenotype}/deeprvat/association_dataset.pkl' - threads: 4 - shell: - 'deeprvat_associate make-dataset ' - + debug + - '{input.config} ' - '{output}' - -rule reverse_models: - input: - checkpoints = expand('models/repeat_{repeat}/best/bag_{bag}.ckpt', - bag=range(n_bags), repeat=range(n_repeats)), - model_config = 'models/config.yaml', - data_config = Path(phenotypes[0]) / "deeprvat/hpopt_config.yaml", - output: - "models/reverse_finished.tmp" - threads: 4 - shell: - " && ".join([ - ("deeprvat_associate reverse-models " - "{input.model_config} " - "{input.data_config} " - "{input.checkpoints}"), - "touch {output}" - ]) - rule all_training: input: - expand('models/repeat_{repeat}/best/bag_{bag}.ckpt', + expand(model_path / 'repeat_{repeat}/best/bag_{bag}.ckpt', bag=range(n_bags), repeat=range(n_repeats)), - "models/config.yaml" - -rule link_config: - input: - 'models/repeat_0/config.yaml' - output: - "models/config.yaml" - threads: 1 - shell: - "ln -s repeat_0/config.yaml {output}" - -rule best_training_run: - input: - expand('models/repeat_{{repeat}}/trial{trial_number}/config.yaml', - trial_number=range(n_trials)), - output: - checkpoints = expand('models/repeat_{{repeat}}/best/bag_{bag}.ckpt', - bag=range(n_bags)), - config = 'models/repeat_{repeat}/config.yaml' - threads: 1 - shell: - ( - 'deeprvat_train best-training-run ' - + debug + - 'models/repeat_{wildcards.repeat} ' - 'models/repeat_{wildcards.repeat}/best ' - 'models/repeat_{wildcards.repeat}/hyperparameter_optimization.db ' - '{output.config}' - ) - -rule train: - input: - config = expand('{phenotype}/deeprvat/hpopt_config.yaml', - phenotype=training_phenotypes), - input_tensor = expand('{phenotype}/deeprvat/input_tensor.zarr', - phenotype=training_phenotypes), - covariates = expand('{phenotype}/deeprvat/covariates.zarr', - phenotype=training_phenotypes), - y = expand('{phenotype}/deeprvat/y.zarr', - phenotype=training_phenotypes), - output: - expand('models/repeat_{repeat}/trial{trial_number}/config.yaml', - repeat=range(n_repeats), trial_number=range(n_trials)), - expand('models/repeat_{repeat}/trial{trial_number}/finished.tmp', - repeat=range(n_repeats), trial_number=range(n_trials)) - params: - phenotypes = " ".join( - [f"--phenotype {p} " - f"{p}/deeprvat/input_tensor.zarr " - f"{p}/deeprvat/covariates.zarr " - f"{p}/deeprvat/y.zarr" - for p in training_phenotypes]) - shell: - f"parallel --jobs {n_parallel_training_jobs} --halt now,fail=1 --results train_repeat{{{{1}}}}_trial{{{{2}}}}/ " - 'deeprvat_train train ' - + debug + - '--trial-id {{2}} ' - "{params.phenotypes} " - 'config.yaml ' - 'models/repeat_{{1}}/trial{{2}} ' - "models/repeat_{{1}}/hyperparameter_optimization.db '&&' " - "touch models/repeat_{{1}}/trial{{2}}/finished.tmp " - "::: " + " ".join(map(str, range(n_repeats))) + " " - "::: " + " ".join(map(str, range(n_trials))) + model_path / "config.yaml" rule all_training_dataset: input: @@ -269,43 +65,6 @@ rule all_training_dataset: y = expand('{phenotype}/deeprvat/y.zarr', phenotype=training_phenotypes, repeat=range(n_repeats)) -rule training_dataset: - input: - config = '{phenotype}/deeprvat/hpopt_config.yaml', - training_dataset = '{phenotype}/deeprvat/training_dataset.pkl' - output: - input_tensor = directory('{phenotype}/deeprvat/input_tensor.zarr'), - covariates = directory('{phenotype}/deeprvat/covariates.zarr'), - y = directory('{phenotype}/deeprvat/y.zarr') - threads: 8 - priority: 50 - shell: - ( - 'deeprvat_train make-dataset ' - + debug + - '--compression-level ' + str(tensor_compression_level) + ' ' - '--training-dataset-file {input.training_dataset} ' - '{input.config} ' - '{output.input_tensor} ' - '{output.covariates} ' - '{output.y}' - ) - -rule training_dataset_pickle: - input: - '{phenotype}/deeprvat/hpopt_config.yaml' - output: - '{phenotype}/deeprvat/training_dataset.pkl' - threads: 1 - shell: - ( - 'deeprvat_train make-dataset ' - '--pickle-only ' - '--training-dataset-file {output} ' - '{input} ' - 'dummy dummy dummy' - ) - rule all_config: input: seed_genes = expand('{phenotype}/deeprvat/seed_genes.parquet', @@ -314,32 +73,3 @@ rule all_config: phenotype=phenotypes), baseline = expand('{phenotype}/deeprvat/baseline_results.parquet', phenotype=phenotypes), - -rule config: - input: - config = 'config.yaml', - baseline = lambda wildcards: [ - str(Path(r['base']) / wildcards.phenotype / r['type'] / - 'eval/burden_associations.parquet') - for r in config['baseline_results'] - ] - output: - seed_genes = '{phenotype}/deeprvat/seed_genes.parquet', - config = '{phenotype}/deeprvat/hpopt_config.yaml', - baseline = '{phenotype}/deeprvat/baseline_results.parquet', - threads: 1 - params: - baseline_results = lambda wildcards, input: ''.join([ - f'--baseline-results {b} ' - for b in input.baseline - ]) - shell: - ( - 'deeprvat_config update-config ' - '--phenotype {wildcards.phenotype} ' - '{params.baseline_results}' - '--baseline-results-out {output.baseline} ' - '--seed-genes-out {output.seed_genes} ' - '{input.config} ' - '{output.config}' - )