Skip to content

Commit

Permalink
Add in Snakemake Log-Files (#147)
Browse files Browse the repository at this point in the history
* remove double logging

* add in logging redirect

* add in logging redirect
- remove logging from train rule (already performed with parallel)
- remove params.prefix bug from log paths

* bugfix logging redirect for associate.py
- add in logging handlers.clear()
- define logging level

* output additional fields in model_config.yaml, to be used for pretrained_models setup

* adding in logging directive to cv pipeline

* add in log to final regenie pipeline

* fixup! Format Python code with psf/black pull_request

---------

Co-authored-by: PMBio <PMBio@users.noreply.github.com>
  • Loading branch information
meyerkm and PMBio authored Dec 4, 2024
1 parent 8d8e0cb commit 38c09db
Show file tree
Hide file tree
Showing 31 changed files with 240 additions and 58 deletions.
2 changes: 1 addition & 1 deletion deeprvat/cv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/data/dense_gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/data/rare.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
23 changes: 12 additions & 11 deletions deeprvat/deeprvat/associate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
from tqdm import tqdm, trange
import zarr
import re

import deeprvat.deeprvat.models as deeprvat_models
from deeprvat.data import DenseGTDataset

logging.root.handlers.clear() # Remove all handlers associated with the root logger object
logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level=logging.INFO,
Expand All @@ -48,6 +48,11 @@
AGG_FCT = {"mean": np.mean, "max": np.max}


@click.group()
def cli():
pass


def get_burden(
batch: Dict,
agg_models: Dict[str, List[nn.Module]],
Expand Down Expand Up @@ -99,11 +104,6 @@ def separate_parallel_results(results: List) -> Tuple[List, ...]:
return tuple(map(list, zip(*results)))


@click.group()
def cli():
pass


def make_dataset_(
config: Dict,
debug: bool = False,
Expand Down Expand Up @@ -306,7 +306,6 @@ def make_regenie_input_(
gene_metadata_file: Path,
gtf: Path,
):
logger.setLevel(logging.INFO)

## Check options
if not skip_burdens and burdens_genes_samples is None:
Expand Down Expand Up @@ -420,7 +419,7 @@ def make_regenie_input_(
if average_repeats:
logger.info("Averaging burdens across all repeats")
burdens = np.zeros((n_samples, n_genes))
for repeat in trange(burdens_zarr.shape[2]):
for repeat in trange(burdens_zarr.shape[2], file=sys.stdout):
burdens += burdens_zarr[:n_samples, :, repeat]
burdens = burdens / burdens_zarr.shape[2]
else:
Expand Down Expand Up @@ -448,7 +447,7 @@ def make_regenie_input_(
n_samples,
samples=list(sample_ids.astype(str)),
) as f:
for i in trange(n_genes):
for i in trange(n_genes, file=sys.stdout):
varid = f"pseudovariant_gene_{ensgids[i]}"
this_burdens = burdens[:, i] # Rescale scores to be in range (0, 2)
genotypes = np.stack(
Expand Down Expand Up @@ -746,7 +745,7 @@ def load_models(
}

if len(checkpoint_files[first_repeat]) > 1:
logging.info(
logger.info(
f" Averaging results from {len(checkpoint_files[first_repeat])} models for each repeat"
)

Expand Down Expand Up @@ -1064,7 +1063,9 @@ def combine_burden_chunks_(
end_id = 0

for i, chunk in tqdm(
enumerate(range(0, n_chunks)), desc=f"Merging {n_chunks} chunks"
enumerate(range(0, n_chunks)),
desc=f"Merging {n_chunks} chunks",
file=sys.stdout,
):
chunk_dir = burdens_chunks_dir / f"chunk_{chunk}"

Expand Down
2 changes: 1 addition & 1 deletion deeprvat/deeprvat/common_variant_condition_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/deeprvat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/deeprvat/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/deeprvat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
29 changes: 21 additions & 8 deletions deeprvat/deeprvat/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pprint import pformat, pprint
from tempfile import TemporaryDirectory
from typing import Dict, Optional, Tuple, Union

import re
import click
import math
import numpy as np
Expand Down Expand Up @@ -37,10 +37,9 @@
from torch.utils.data import DataLoader, Dataset, Subset
from tqdm import tqdm


logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -872,20 +871,20 @@ def run_bagging(
trainer.fit(model, dm)
except RuntimeError as e:
# if batch_size is chosen too big, it will be reduced until it fits the GPU
logging.error(f"Caught RuntimeError: {e}")
logger.error(f"Caught RuntimeError: {e}")
if str(e).find("CUDA out of memory") != -1:
if dm.hparams.batch_size > 4:
logging.error(
logger.error(
"Retrying training with half the original batch size"
)
gc.collect()
torch.cuda.empty_cache()
dm.hparams.batch_size = dm.hparams.batch_size // 2
else:
logging.error("Batch size is already <= 4, giving up")
logger.error("Batch size is already <= 4, giving up")
raise RuntimeError("Could not find small enough batch size")
else:
logging.error(f"Caught unknown error: {e}")
logger.error(f"Caught unknown error: {e}")
raise e
else:
break
Expand Down Expand Up @@ -1167,7 +1166,21 @@ def best_training_run(
config = yaml.safe_load(f)

with open(config_file_out, "w") as f:
yaml.dump({"model": config["model"]}, f)
yaml.dump(
{
"model": config["model"],
"rare_variant_annotations": config["training_data"]["dataset_config"][
"rare_embedding"
]["config"]["annotations"],
"training_data_thresholds": {
k: str(re.sub(f"^{k} ", "", v))
for k, v in config["training_data"]["dataset_config"][
"rare_embedding"
]["config"]["thresholds"].items()
},
},
f,
)

n_bags = config["training"]["n_bags"] if not debug else 3
for k in range(n_bags):
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/seed_gene_discovery/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/seed_gene_discovery/seed_gene_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion deeprvat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

logging.basicConfig(
format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
level="INFO",
level=logging.INFO,
stream=sys.stdout,
)
logger = logging.getLogger(__name__)
Expand Down
12 changes: 10 additions & 2 deletions pipelines/association_testing/association_dataset.snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ rule association_dataset:
resources:
mem_mb = lambda wildcards, attempt: 32000 * (attempt + 1),
priority: 30
log:
stdout="logs/association_dataset/{phenotype}.stdout",
stderr="logs/association_dataset/{phenotype}.stderr"
shell:
'deeprvat_associate make-dataset '
+ debug +
"--skip-genotypes "
'{input.data_config} '
'{output}'
'{output} '
+ logging_redirct


rule association_dataset_burdens:
Expand All @@ -33,8 +37,12 @@ rule association_dataset_burdens:
resources:
mem_mb = lambda wildcards, attempt: 32000 * (attempt + 1)
priority: 30
log:
stdout=f"logs/association_dataset_burdens/{phenotypes[0]}.stdout",
stderr=f"logs/association_dataset_burdens/{phenotypes[0]}.stderr"
shell:
'deeprvat_associate make-dataset '
+ debug +
'{input.data_config} '
'{output}'
'{output} '
+ logging_redirct
27 changes: 22 additions & 5 deletions pipelines/association_testing/burdens.snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ rule combine_burdens:
threads: 1
resources:
mem_mb = lambda wildcards, attempt: 4098 + (attempt - 1) * 4098,
log:
stdout="logs/combine_burdens/combine_burdens.stdout",
stderr="logs/combine_burdens/combine_burdens.stderr"
shell:
' '.join([
'deeprvat_associate combine-burden-chunks',
'{params.prefix}/burdens/chunks/',
' --n-chunks ' + str(n_burden_chunks),
'{params.prefix}/burdens',
'{params.prefix}/burdens ',
logging_redirct
])

rule all_xy:
Expand All @@ -42,14 +46,18 @@ rule compute_xy:
threads: 8
resources:
mem_mb = lambda wildcards, attempt: 20480 + (attempt - 1) * 4098,
log:
stdout="logs/compute_xy/{phenotype}.stdout",
stderr="logs/compute_xy/{phenotype}.stderr"
shell:
' && '.join([
('deeprvat_associate compute-xy '
'--dataset-file {input.dataset} '
'{input.data_config} '
"{output.samples} "
"{output.x} "
"{output.y}")
"{output.y} "
+ logging_redirct)
])


Expand All @@ -73,6 +81,9 @@ rule compute_burdens:
resources:
mem_mb = 32000,
gpus = 1
log:
stdout="logs/compute_burdens/compute_burdens_{chunk}.stdout",
stderr="logs/compute_burdens/compute_burdens_{chunk}.stderr"
shell:
' '.join([
'deeprvat_associate compute-burdens '
Expand All @@ -83,7 +94,8 @@ rule compute_burdens:
'{input.data_config} '
'{input.model_config} '
'{input.checkpoints} '
'{params.prefix}/burdens'],
'{params.prefix}/burdens '
+ logging_redirct ],
)


Expand All @@ -98,11 +110,16 @@ rule reverse_models:
threads: 4
resources:
mem_mb = 20480,
log:
stdout="logs/reverse_models/reverse_models.stdout",
stderr="logs/reverse_models/reverse_models.stderr"
shell:
" && ".join([
("deeprvat_associate reverse-models "
"{input.model_config} "
"{input.data_config} "
"{input.checkpoints}"),
"touch {output}"
"{input.checkpoints} "
+ logging_redirct),
"touch {output} "

])
Loading

0 comments on commit 38c09db

Please sign in to comment.