diff --git a/pipelines/training_association_testing.snakefile b/pipelines/training_association_testing.snakefile index 069602b6..70429f0b 100644 --- a/pipelines/training_association_testing.snakefile +++ b/pipelines/training_association_testing.snakefile @@ -14,6 +14,7 @@ n_repeats = config['n_repeats'] debug = '--debug ' if debug_flag else '' do_scoretest = '--do-scoretest ' if config.get('do_scoretest', False) else '' tensor_compression_level = config['training'].get('tensor_compression_level', 1) +n_parallel_training_jobs = config["training"].get("n_parallel_jobs") wildcard_constraints: repeat="\d+", @@ -235,8 +236,10 @@ rule train: y = expand('{phenotype}/deeprvat/y.zarr', phenotype=phenotypes), output: - config = 'models/repeat_{repeat}/trial{trial_number}/config.yaml', - finished = 'models/repeat_{repeat}/trial{trial_number}/finished.tmp' + expand('models/repeat_{repeat}/trial{trial_number}/config.yaml', + repeat=range(n_repeats), trial_number=range(n_trials)), + expand('models/repeat_{repeat}/trial{trial_number}/finished.tmp', + repeat=range(n_repeats), trial_number=range(n_trials))) params: phenotypes = " ".join( [f"--phenotype {p} " @@ -245,16 +248,17 @@ rule train: f"{p}/deeprvat/y.zarr" for p in phenotypes]) shell: - ' && '.join([ - 'deeprvat_train train ' - + debug + - '--trial-id {wildcards.trial_number} ' - "{params.phenotypes} " - 'config.yaml ' - 'models/repeat_{wildcards.repeat}/trial{wildcards.trial_number} ' - 'models/repeat_{wildcards.repeat}/hyperparameter_optimization.db', - 'touch {output.finished}' - ]) + f"parallel --jobs {n_parallel_training_jobs} --results train_repeat{{1}}_trial{{2}}/ " + 'deeprvat_train train ' + + debug + + '--trial-id {{2}} ' + "{params.phenotypes} " + 'config.yaml ' + 'models/repeat_{{1}}/trial{{2}} ' + "models/repeat_{{1}}/hyperparameter_optimization.db '&&' " + "touch models/repeat_{{1}}/trial{{2}}/finished.tmp " + "::: " + " ".join(range(n_repeats)) + " " + "::: " + " ".join(range(n_trials)) rule all_training_dataset: input: