Skip to content

Commit

Permalink
Initial snakemake file
Browse files Browse the repository at this point in the history
  • Loading branch information
endast committed May 2, 2024
1 parent 71d8d7a commit 3d0899c
Showing 1 changed file with 73 additions and 60 deletions.
133 changes: 73 additions & 60 deletions pipelines/association_testing/burdens.snakefile
Original file line number Diff line number Diff line change
@@ -1,103 +1,116 @@
rule average_burdens:
input:
chunks = [
(f'{p}/deeprvat/burdens/chunk{c}.' +
("finished" if p == phenotypes[0] else "linked"))
for p in phenotypes
for c in range(n_burden_chunks)
] if not cv_exp else '{phenotype}/deeprvat/burdens/merging.finished'
burdens='{phenotype}/deeprvat/burdens/burdens.zarr',
x='{phenotype}/deeprvat/burdens/x.zarr',
y='{phenotype}/deeprvat/burdens/y.zarr',
sample_ids='{phenotype}/deeprvat/burdens/sample_ids.zarr',
output:
'{phenotype}/deeprvat/burdens/logs/burdens_averaging_{chunk}.finished',
params:
burdens_in = '{phenotype}/deeprvat/burdens/burdens.zarr',
burdens_out = '{phenotype}/deeprvat/burdens/burdens_average.zarr',
repeats = lambda wildcards: ''.join([f'--repeats {r} ' for r in range(int(n_repeats))])
burdens_in='{phenotype}/deeprvat/burdens/burdens.zarr',
burdens_out='{phenotype}/deeprvat/burdens/burdens_average.zarr',
repeats=lambda wildcards: ''.join([f'--repeats {r} ' for r in range(int(n_repeats))])
threads: 1
resources:
mem_mb = lambda wildcards, attempt: 4098 + (attempt - 1) * 4098,
mem_mb=lambda wildcards, attempt: 4098 + (attempt - 1) * 4098,
priority: 10,
shell:
' && '.join([
('deeprvat_associate average-burdens '
'--n-chunks '+ str(n_avg_chunks) + ' '
'--chunk {wildcards.chunk} '
'{params.repeats} '
'--agg-fct mean ' #TODO remove this
'{params.burdens_in} '
'{params.burdens_out}'),
'--n-chunks ' + str(n_avg_chunks) + ' '
'--chunk {wildcards.chunk} '
'{params.repeats} '
'--agg-fct mean ' #TODO remove this
'{params.burdens_in} '
'{params.burdens_out}'),
'touch {output}'
])

rule link_burdens:
priority: 1
input:
checkpoints = lambda wildcards: [
checkpoints=lambda wildcards: [
f'{model_path}/repeat_{repeat}/best/bag_{bag}.ckpt'
for repeat in range(n_repeats) for bag in range(n_bags)
],
dataset = '{phenotype}/deeprvat/association_dataset.pkl',
data_config = '{phenotype}/deeprvat/hpopt_config.yaml',
model_config = model_path / 'config.yaml',
dataset='{phenotype}/deeprvat/association_dataset.pkl',
data_config='{phenotype}/deeprvat/hpopt_config.yaml',
model_config=model_path / 'config.yaml',
output:
'{phenotype}/deeprvat/burdens/chunk{chunk}.linked'
params:
prefix = '.'
prefix='.'
threads: 8
resources:
mem_mb = lambda wildcards, attempt: 20480 + (attempt - 1) * 4098,
mem_mb=lambda wildcards, attempt: 20480 + (attempt - 1) * 4098,
shell:
' && '.join([
('deeprvat_associate compute-burdens '
+ debug +
' --n-chunks '+ str(n_burden_chunks) + ' '
f'--link-burdens ../../../{phenotypes[0]}/deeprvat/burdens/burdens.zarr '
'--chunk {wildcards.chunk} '
'--dataset-file {input.dataset} '
'{input.data_config} '
'{input.model_config} '
'{input.checkpoints} '
'{params.prefix}/{wildcards.phenotype}/deeprvat/burdens'),
' --n-chunks ' + str(n_burden_chunks) + ' '
f'--link-burdens ../../../{phenotypes[0]}/deeprvat/burdens/burdens.zarr '
'--chunk {wildcards.chunk} '
'--dataset-file {input.dataset} '
'{input.data_config} '
'{input.model_config} '
'{input.checkpoints} '
'{params.prefix}/{wildcards.phenotype}/deeprvat/burdens'),
'touch {output}'
])

rule compute_burdens:
priority: 10
input:
reversed = model_path / "reverse_finished.tmp",
checkpoints = lambda wildcards: [
reversed=model_path / "reverse_finished.tmp",
checkpoints=lambda wildcards: [
f'{model_path}/repeat_{repeat}/best/bag_{bag}.ckpt'
for repeat in range(n_repeats) for bag in range(n_bags)
],
dataset = '{phenotype}/deeprvat/association_dataset.pkl',
data_config = '{phenotype}/deeprvat/hpopt_config.yaml',
model_config = model_path / 'config.yaml',
dataset='{phenotype}/deeprvat/association_dataset.pkl',
data_config='{phenotype}/deeprvat/hpopt_config.yaml',
model_config=model_path / 'config.yaml',
output:
'{phenotype}/deeprvat/burdens/chunk{chunk}.finished'
burdens='{phenotype}/deeprvat/burdens/chunks/chunk{chunk}/burdens.zarr',
x='{phenotype}/deeprvat/burdens/chunks/chunk{chunk}/x.zarr',
y='{phenotype}/deeprvat/burdens/chunks/chunk{chunk}/y.zarr',
sample_ids='{phenotype}/deeprvat/burdens/chunks/chunk{chunk}/sample_ids.zarr',
params:
prefix = '.'
prefix='.'
threads: 8
resources:
mem_mb = 20000,
gpus = 1
mem_mb=20000,
gpus=1
shell:
' && '.join([
('deeprvat_associate compute-burdens '
+ debug +
' --n-chunks '+ str(n_burden_chunks) + ' '
'--chunk {wildcards.chunk} '
'--dataset-file {input.dataset} '
'{input.data_config} '
'{input.model_config} '
'{input.checkpoints} '
'{params.prefix}/{wildcards.phenotype}/deeprvat/burdens'),
'touch {output}'])
' '.join([
'deeprvat_associate compute-burdens '
+ debug +
' --n-chunks ' + str(n_burden_chunks) + ' '
'--chunk {wildcards.chunk} '
'--dataset-file {input.dataset} '
'{input.data_config} '
'{input.model_config} '
'{input.checkpoints} '
'{params.prefix}/{wildcards.phenotype}/deeprvat/burdens'],
)

rule combine_burdens:
input:
burdens='{phenotype}/deeprvat/burdens/chunks/chunk{chunk}/burdens.zarr',
x='{phenotype}/deeprvat/burdens/chunks/chunk{chunk}/x.zarr',
y='{phenotype}/deeprvat/burdens/chunks/chunk{chunk}/y.zarr',
sample_ids='{phenotype}/deeprvat/burdens/chunks/chunk{chunk}/sample_ids.zarr',
expand(
'{phenotype}/deeprvat/burdens/chunks/chunk{chunk}/burdens.zarr',
chunk=[c for c in range(n_burden_chunks)],
phenotype=phenotypes),
expand(
'{phenotype}/deeprvat/burdens/chunks/chunk{chunk}/x.zarr',
chunk=[c for c in range(n_burden_chunks)],
phenotype=phenotypes),
expand(
'{phenotype}/deeprvat/burdens/chunks/chunk{chunk}/y.zarr',
chunk=[c for c in range(n_burden_chunks)],
phenotype=phenotypes),
expand(
'{phenotype}/deeprvat/burdens/chunks/chunk{chunk}/sample_ids.zarr',
chunk=[c for c in range(n_burden_chunks)],
phenotype=phenotypes)
output:
burdens='{phenotype}/deeprvat/burdens/burdens.zarr',
x='{phenotype}/deeprvat/burdens/x.zarr',
Expand All @@ -107,23 +120,23 @@ rule combine_burdens:
prefix='.'
shell:
' '.join([
"'{phenotype}/deeprvat/burdens/chunks/",
"'{wildcards.phenotype}/deeprvat/burdens/chunks/",
'deeprvat_associate combine-burden-chunks',
' --n-chunks ' + str(n_burden_chunks),
'{params.prefix}/{wildcards.phenotype}/deeprvat/burdens',
])

rule reverse_models:
input:
checkpoints = expand(model_path / 'repeat_{repeat}/best/bag_{bag}.ckpt',
bag=range(n_bags), repeat=range(n_repeats)),
model_config = model_path / 'config.yaml',
data_config = Path(phenotypes[0]) / "deeprvat/hpopt_config.yaml",
checkpoints=expand(model_path / 'repeat_{repeat}/best/bag_{bag}.ckpt',
bag=range(n_bags),repeat=range(n_repeats)),
model_config=model_path / 'config.yaml',
data_config=Path(phenotypes[0]) / "deeprvat/hpopt_config.yaml",
output:
model_path / "reverse_finished.tmp"
threads: 4
resources:
mem_mb = 20480,
mem_mb=20480,
shell:
" && ".join([
("deeprvat_associate reverse-models "
Expand Down

0 comments on commit 3d0899c

Please sign in to comment.