Feature streamline cv training (#71)
* reduce number of jobs necessary for CV training

* update tests for changes in dataset class

* fixup! Format Python code with psf/black pull_request

---------

Co-authored-by: Brian Clarke <brian.clarke@dkfz.de>
Co-authored-by: PMBio <PMBio@users.noreply.github.com>
3 people authored Apr 22, 2024
1 parent 7179ce9 commit d9ede28
Showing 13 changed files with 244 additions and 129 deletions. Only the diffs for the first three files are included below.
deeprvat/cv_utils.py (11 changes: 4 additions & 7 deletions)
@@ -59,11 +59,9 @@ def spread_config(
     cv_path = f"{config_template['cv_path']}/{n_folds}_fold"
     for module in data_modules:
         config = copy.deepcopy(config_template)
-        data_slots = DATA_SLOT_DICT[module]
-        for data_slot in data_slots:
-            sample_file = f"{cv_path}/samples_{split}{fold}.pkl"
-            logger.info(f"setting sample file {sample_file}")
-            config[data_slot]["dataset_config"]["sample_file"] = sample_file
+        sample_file = f"{cv_path}/samples_{split}{fold}.pkl"
+        logger.info(f"setting sample file {sample_file}")
+        config["sample_file"] = sample_file
 
         if (module == "deeprvat") | (module == "deeprvat_pretrained"):
             logger.info("Writing baseline directories")
@@ -91,8 +89,7 @@ def generate_test_config(input_config, out_file, fold, n_folds):
     split = "test"
     sample_file = f"{cv_path}/samples_{split}{fold}.pkl"
     logger.info(f"setting sample file {sample_file}")
-    for data_slot in DATA_SLOT_DICT["deeprvat"]:
-        config[data_slot]["dataset_config"]["sample_file"] = sample_file
+    config["sample_file"] = sample_file
     with open(out_file, "w") as f:
         yaml.dump(config, f)
 
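The net effect of the two hunks above is that cross-validation configs now carry a single top-level sample_file key instead of one copy per data slot. As an aside (not part of the diff), a minimal sketch of the before/after config shape; the data-slot names and paths here are hypothetical placeholders, not values taken from the repository:

# Illustrative sketch only; slot names and paths are hypothetical.
old_style_config = {
    "data": {"dataset_config": {"sample_file": "cv_path/5_fold/samples_train0.pkl"}},
    "training_data": {"dataset_config": {"sample_file": "cv_path/5_fold/samples_train0.pkl"}},
}
new_style_config = {
    # One top-level key, read directly by downstream steps such as compute_burdens_.
    "sample_file": "cv_path/5_fold/samples_train0.pkl",
}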
deeprvat/data/dense_gt.py (22 changes: 14 additions & 8 deletions)
@@ -362,7 +362,7 @@ def setup_phenotypes(
         # account for the fact that genotypes.h5 and phenotype_df can have different
         # orders of their samples
         self.index_map_geno, _ = get_matched_sample_indices(
-            samples_gt.astype(int), self.samples.astype(int)
+            samples_gt.astype(str), self.samples.astype(str)
         )
         # get_matched_sample_indices is a much, much faster implementation of the code below
         # self.index_map_geno = [np.where(samples_gt.astype(int) == i) for i in self.samples.astype(int)]
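Casting sample IDs to str rather than int keeps the genotype/phenotype matching from failing on non-numeric IDs. As a rough illustration (not the deeprvat implementation, which uses get_matched_sample_indices), the same index matching can be done with NumPy along these lines:

import numpy as np

def match_sample_indices(samples_gt, samples):
    # For each entry of `samples`, find its row index in `samples_gt`.
    # Comparison is done on strings so IDs like "HG00096" survive the cast.
    # Assumes every requested sample is present in samples_gt.
    samples_gt = np.asarray(samples_gt).astype(str)
    samples = np.asarray(samples).astype(str)
    order = np.argsort(samples_gt)
    pos = np.searchsorted(samples_gt, samples, sorter=order)
    return order[pos]

# match_sample_indices(["s3", "s1", "s2"], ["s1", "s3"])  ->  array([1, 0])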
@@ -614,13 +614,20 @@ def setup_variants(
                     "Annotation dataframe has inconsistent allele frequency values"
                 )
             variants_with_af = safe_merge(
-                variants[["id"]].reset_index(drop=True), af_annotation
+                variants[["id"]].reset_index(drop=True), af_annotation, how="left"
             )
             assert np.all(
                 variants_with_af["id"].to_numpy() == variants["id"].to_numpy()
             )
-            mask = (variants_with_af[af_col] >= af_threshold) & (
-                variants_with_af[af_col] <= 1 - af_threshold
+            af_isna = variants_with_af[af_col].isna()
+            if af_isna.sum() > 0:
+                logger.warning(
+                    f"Dropping {af_isna.sum()} variants missing from annotation dataframe"
+                )
+            mask = (
+                (~af_isna)
+                & (variants_with_af[af_col] >= af_threshold)
+                & (variants_with_af[af_col] <= 1 - af_threshold)
             )
             mask = mask.to_numpy()
             del variants_with_af
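This hunk switches the merge to a left join, so variants absent from the annotation dataframe no longer abort the merge; they are logged and then excluded via the mask. A self-contained pandas sketch of the same masking pattern (toy data; the column name "af" is illustrative):

import numpy as np
import pandas as pd

variants = pd.DataFrame({"id": [0, 1, 2, 3]})
af_annotation = pd.DataFrame({"id": [0, 1, 3], "af": [0.5, 0.001, 0.2]})
af_threshold = 0.01

# Left merge keeps all variants; id 2 gets NaN allele frequency.
variants_with_af = variants.merge(af_annotation, on="id", how="left")
af_isna = variants_with_af["af"].isna()
mask = (
    (~af_isna)
    & (variants_with_af["af"] >= af_threshold)
    & (variants_with_af["af"] <= 1 - af_threshold)
).to_numpy()
print(mask)  # [ True False False  True]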
@@ -931,11 +938,10 @@ def get_metadata(self) -> Dict[str, Any]:
         result = {
             "variant_metadata": self.variants[
                 ["id", "common_variant_mask", "rare_variant_mask", "matrix_index"]
-            ]
+            ],
+            "samples": self.samples,
         }
         if self.use_rare_variants:
             if hasattr(self.rare_embedding, "get_metadata"):
-                result.update(
-                    {"rare_embedding_metadata": self.rare_embedding.get_metadata()}
-                )
+                result["rare_embedding_metadata"] = self.rare_embedding.get_metadata()
         return result
deeprvat/deeprvat/associate.py (60 changes: 38 additions & 22 deletions)
@@ -7,7 +7,7 @@
 import sys
 from pathlib import Path
 from pprint import pprint
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import click
 import dask.dataframe as dd
@@ -115,9 +115,7 @@ def cli():
 
 def make_dataset_(
     config: Dict,
-    debug: bool = False,
     data_key="data",
-    samples: Optional[List[int]] = None,
 ) -> Dataset:
     """
     Create a dataset based on the configuration.
@@ -149,29 +147,17 @@ def make_dataset_(
         **copy.deepcopy(data_config["dataset_config"]),
     )
 
-    restrict_samples = config.get("restrict_samples", None)
-    if debug:
-        logger.info("Debug flag set; Using only 1000 samples")
-        ds = Subset(ds, range(1_000))
-    elif samples is not None:
-        ds = Subset(ds, samples)
-    elif restrict_samples is not None:
-        ds = Subset(ds, range(restrict_samples))
-
     return ds
 
 
 @cli.command()
-@click.option("--debug", is_flag=True)
 @click.option("--data-key", type=str, default="data")
-@click.argument("config-file", type=click.Path(exists=True))
-@click.argument("out-file", type=click.Path())
-def make_dataset(debug: bool, data_key: str, config_file: str, out_file: str):
+@click.argument("config-file", type=click.Path(exists=True, path_type=Path))
+@click.argument("out-file", type=click.Path(path_type=Path))
+def make_dataset(data_key: str, config_file: Path, out_file: Path):
     """
     Create a dataset based on the provided configuration and save to a pickle file.
-    :param debug: Flag for debugging.
-    :type debug: bool
     :param data_key: Key for dataset configuration in the config dictionary, defaults to "data".
     :type data_key: str
     :param config_file: Path to the configuration file.
@@ -183,7 +169,7 @@ def make_dataset(debug: bool, data_key: str, config_file: str, out_file: str):
     with open(config_file) as f:
         config = yaml.safe_load(f)
 
-    ds = make_dataset_(config, debug=debug, data_key=data_key)
+    ds = make_dataset_(config, data_key=data_key)
 
     with open(out_file, "wb") as f:
         pickle.dump(ds, f)
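The CLI change above uses click's path_type so the command receives pathlib.Path objects instead of plain strings. A minimal standalone example of the same pattern (hypothetical command, not part of the deeprvat CLI), requiring click >= 8:

from pathlib import Path

import click

@click.command()
@click.argument("config-file", type=click.Path(exists=True, path_type=Path))
def show(config_file: Path):
    # click converts the argument to pathlib.Path, so .stem, .suffix etc. work directly.
    click.echo(f"{config_file.stem}: {config_file.stat().st_size} bytes")

if __name__ == "__main__":
    show()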
@@ -236,6 +222,8 @@ def compute_burdens_(
     .. note::
         Checkpoint models all corresponding to the same repeat are averaged for that repeat.
     """
+    logger.setLevel(logging.INFO)
+
     if not skip_burdens:
         logger.info("agg_models[*][*].reverse:")
         pprint(
@@ -247,10 +235,38 @@
 
     data_config = config["data"]
 
-    ds_full = ds.dataset if isinstance(ds, Subset) else ds
+    if "sample_file" in config:
+        sample_file = Path(config["sample_file"])
+        logger.info(f"Using samples from {sample_file}")
+        if sample_file.suffix == ".pkl":
+            with open(sample_file, "rb") as f:
+                sample_ids = np.array(pickle.load(f))
+        elif sample_file.suffix == ".zarr":
+            sample_ids = zarr.load(sample_file)
+        elif sample_file.suffix == ".npy":
+            sample_ids = np.load(sample_file)
+        else:
+            raise ValueError("Unknown file type for sample_file")
+        ds_samples = ds.get_metadata()["samples"]
+        sample_indices = np.where(
+            np.isin(ds_samples.astype(str), sample_ids.astype(str))
+        )[0]
+        if debug:
+            sample_indices = sample_indices[:1000]
+    elif debug:
+        sample_indices = np.arange(min(1000, len(ds)))
+    else:
+        sample_indices = np.arange(len(ds))
+
+    logger.info(
+        f"Computing gene impairment for {sample_indices.shape[0]} samples: {sample_indices}"
+    )
+    ds = Subset(ds, sample_indices)
+
+    ds_full = ds.dataset  # if isinstance(ds, Subset) else ds
     collate_fn = getattr(ds_full, "collate_fn", None)
     n_total_samples = len(ds)
-    ds.rare_embedding.skip_embedding = skip_burdens
+    ds_full.rare_embedding.skip_embedding = skip_burdens
 
     if chunk is not None:
         if n_chunks is None:
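With the dataset subset to the fold's samples inside compute_burdens_, a separate dataset job per CV fold is no longer needed, which is the job reduction named in the commit message. A standalone sketch of that subsetting step under the same assumptions as the hunk above (the helper name and arguments are illustrative, not the deeprvat API):

import pickle
from pathlib import Path

import numpy as np
import zarr
from torch.utils.data import Subset

def subset_to_sample_file(ds, ds_samples, sample_file: Path) -> Subset:
    # Load the fold's sample IDs from .pkl/.zarr/.npy, then keep only the
    # dataset rows whose IDs appear in that list (compared as strings).
    if sample_file.suffix == ".pkl":
        with open(sample_file, "rb") as f:
            sample_ids = np.array(pickle.load(f))
    elif sample_file.suffix == ".zarr":
        sample_ids = zarr.load(str(sample_file))
    elif sample_file.suffix == ".npy":
        sample_ids = np.load(sample_file)
    else:
        raise ValueError(f"Unknown sample_file type: {sample_file.suffix}")
    keep = np.where(
        np.isin(np.asarray(ds_samples).astype(str), np.asarray(sample_ids).astype(str))
    )[0]
    return Subset(ds, keep)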
@@ -903,7 +919,7 @@ def compute_burdens(
         with open(dataset_file, "rb") as f:
             dataset = pickle.load(f)
     else:
-        dataset = make_dataset_(config)
+        dataset = make_dataset_(data_config)
 
     if torch.cuda.is_available():
         logger.info("Using GPU")