Skip to content

Commit

Permalink
train multiple repeats on single node in parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
bfclarke committed Nov 22, 2023
1 parent 6761a90 commit 684eb23
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions pipelines/training_association_testing.snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ n_repeats = config['n_repeats']
debug = '--debug ' if debug_flag else ''
do_scoretest = '--do-scoretest ' if config.get('do_scoretest', False) else ''
tensor_compression_level = config['training'].get('tensor_compression_level', 1)
n_parallel_training_jobs = config["training"].get("n_parallel_jobs")

wildcard_constraints:
repeat="\d+",
Expand Down Expand Up @@ -235,8 +236,10 @@ rule train:
y = expand('{phenotype}/deeprvat/y.zarr',
phenotype=phenotypes),
output:
config = 'models/repeat_{repeat}/trial{trial_number}/config.yaml',
finished = 'models/repeat_{repeat}/trial{trial_number}/finished.tmp'
expand('models/repeat_{repeat}/trial{trial_number}/config.yaml',
repeat=range(n_repeats), trial_number=range(n_trials)),
expand('models/repeat_{repeat}/trial{trial_number}/finished.tmp',
repeat=range(n_repeats), trial_number=range(n_trials)))
params:
phenotypes = " ".join(
[f"--phenotype {p} "
Expand All @@ -245,16 +248,17 @@ rule train:
f"{p}/deeprvat/y.zarr"
for p in phenotypes])
shell:
' && '.join([
'deeprvat_train train '
+ debug +
'--trial-id {wildcards.trial_number} '
"{params.phenotypes} "
'config.yaml '
'models/repeat_{wildcards.repeat}/trial{wildcards.trial_number} '
'models/repeat_{wildcards.repeat}/hyperparameter_optimization.db',
'touch {output.finished}'
])
f"parallel --jobs {n_parallel_training_jobs} --results train_repeat{{1}}_trial{{2}}/ "
'deeprvat_train train '
+ debug +
'--trial-id {{2}} '
"{params.phenotypes} "
'config.yaml '
'models/repeat_{{1}}/trial{{2}} '
"models/repeat_{{1}}/hyperparameter_optimization.db '&&' "
"touch models/repeat_{{1}}/trial{{2}}/finished.tmp "
"::: " + " ".join(range(n_repeats)) + " "
"::: " + " ".join(range(n_trials))

rule all_training_dataset:
input:
Expand Down

0 comments on commit 684eb23

Please sign in to comment.