Revert "Feature streamline cv training" #79

Merged · 3 commits · May 13, 2024
Changes from all commits
11 changes: 7 additions & 4 deletions deeprvat/cv_utils.py
@@ -59,9 +59,11 @@ def spread_config(
cv_path = f"{config_template['cv_path']}/{n_folds}_fold"
for module in data_modules:
config = copy.deepcopy(config_template)
sample_file = f"{cv_path}/samples_{split}{fold}.pkl"
logger.info(f"setting sample file {sample_file}")
config["sample_file"] = sample_file
data_slots = DATA_SLOT_DICT[module]
for data_slot in data_slots:
sample_file = f"{cv_path}/samples_{split}{fold}.pkl"
logger.info(f"setting sample file {sample_file}")
config[data_slot]["dataset_config"]["sample_file"] = sample_file

if (module == "deeprvat") | (module == "deeprvat_pretrained"):
logger.info("Writing baseline directories")
@@ -89,7 +91,8 @@ def generate_test_config(input_config, out_file, fold, n_folds):
split = "test"
sample_file = f"{cv_path}/samples_{split}{fold}.pkl"
logger.info(f"setting sample file {sample_file}")
config["sample_file"] = sample_file
for data_slot in DATA_SLOT_DICT["deeprvat"]:
config[data_slot]["dataset_config"]["sample_file"] = sample_file
with open(out_file, "w") as f:
yaml.dump(config, f)

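Note on the two hunks above: the revert moves the cross-validation sample file back from a single top-level `sample_file` key into each data slot's `dataset_config`. A minimal sketch of the resulting config nesting; the slot names and paths below are illustrative assumptions, not taken from the repository:

```python
# Illustrative only: slot names, paths, and DATA_SLOT_DICT shape are assumptions.
import copy

DATA_SLOT_DICT = {"deeprvat": ["data", "training_data"]}  # assumed shape

config_template = {
    "cv_path": "cv_splits",
    "data": {"dataset_config": {}},
    "training_data": {"dataset_config": {}},
}

config = copy.deepcopy(config_template)
sample_file = "cv_splits/5_fold/samples_train0.pkl"
for data_slot in DATA_SLOT_DICT["deeprvat"]:
    # After the revert, the sample file lives under each slot's dataset_config
    config[data_slot]["dataset_config"]["sample_file"] = sample_file

assert config["data"]["dataset_config"]["sample_file"] == sample_file
assert "sample_file" not in config  # no longer a top-level key
```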
22 changes: 8 additions & 14 deletions deeprvat/data/dense_gt.py
@@ -362,7 +362,7 @@ def setup_phenotypes(
# account for the fact that genotypes.h5 and phenotype_df can have different
# orders of their samples
self.index_map_geno, _ = get_matched_sample_indices(
samples_gt.astype(str), self.samples.astype(str)
samples_gt.astype(int), self.samples.astype(int)
)
# get_matched_sample_indices is a much, much faster implementation of the code below
# self.index_map_geno = [np.where(samples_gt.astype(int) == i) for i in self.samples.astype(int)]
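For readers unfamiliar with `get_matched_sample_indices`, the commented-out comprehension above shows the intent: build an index map so rows of the genotype file can be reordered to match `self.samples`. A rough NumPy sketch of that behaviour, not the project's (much faster) implementation:

```python
import numpy as np

def match_sample_indices_naive(samples_gt, samples):
    # For each requested sample ID, find its row index in the genotype file.
    # Same spirit as the commented-out np.where comprehension above.
    pos = {s: i for i, s in enumerate(samples_gt)}
    return np.array([pos[s] for s in samples])

samples_gt = np.array([1003, 1001, 1002])  # order in genotypes.h5 (made-up IDs)
samples = np.array([1001, 1002, 1003])     # order in the phenotype dataframe
idx = match_sample_indices_naive(samples_gt, samples)
assert (samples_gt[idx] == samples).all()
```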
@@ -614,20 +614,13 @@ def setup_variants(
"Annotation dataframe has inconsistent allele frequency values"
)
variants_with_af = safe_merge(
variants[["id"]].reset_index(drop=True), af_annotation, how="left"
variants[["id"]].reset_index(drop=True), af_annotation
)
assert np.all(
variants_with_af["id"].to_numpy() == variants["id"].to_numpy()
)
af_isna = variants_with_af[af_col].isna()
if af_isna.sum() > 0:
logger.warning(
f"Dropping {af_isna.sum()} variants missing from annotation dataframe"
)
mask = (
(~af_isna)
& (variants_with_af[af_col] >= af_threshold)
& (variants_with_af[af_col] <= 1 - af_threshold)
mask = (variants_with_af[af_col] >= af_threshold) & (
variants_with_af[af_col] <= 1 - af_threshold
)
mask = mask.to_numpy()
del variants_with_af
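The reverted mask keeps a variant only if its annotated allele frequency lies in `[af_threshold, 1 - af_threshold]`, filtering out both very rare and reference-dominant variants; the explicit NA-dropping branch added by the feature PR goes away. A small pandas sketch of the same filter, with an assumed column name and threshold:

```python
import pandas as pd

variants_with_af = pd.DataFrame(
    {"id": [0, 1, 2, 3], "af": [0.0005, 0.02, 0.9995, 0.4]}
)
af_threshold = 0.001  # assumed value for illustration

mask = (variants_with_af["af"] >= af_threshold) & (
    variants_with_af["af"] <= 1 - af_threshold
)
# Note: NaN allele frequencies compare False on both sides, so they never pass the mask.
print(variants_with_af.loc[mask.to_numpy(), "id"].tolist())  # [1, 3]
```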
@@ -938,10 +931,11 @@ def get_metadata(self) -> Dict[str, Any]:
result = {
"variant_metadata": self.variants[
["id", "common_variant_mask", "rare_variant_mask", "matrix_index"]
],
"samples": self.samples,
]
}
if self.use_rare_variants:
if hasattr(self.rare_embedding, "get_metadata"):
result["rare_embedding_metadata"] = self.rare_embedding.get_metadata()
result.update(
{"rare_embedding_metadata": self.rare_embedding.get_metadata()}
)
return result
60 changes: 22 additions & 38 deletions deeprvat/deeprvat/associate.py
@@ -7,7 +7,7 @@
import sys
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple

import click
import dask.dataframe as dd
@@ -115,7 +115,9 @@ def cli():

def make_dataset_(
config: Dict,
debug: bool = False,
data_key="data",
samples: Optional[List[int]] = None,
) -> Dataset:
"""
Create a dataset based on the configuration.
@@ -147,17 +149,29 @@ def make_dataset_(
**copy.deepcopy(data_config["dataset_config"]),
)

restrict_samples = config.get("restrict_samples", None)
if debug:
logger.info("Debug flag set; Using only 1000 samples")
ds = Subset(ds, range(1_000))
elif samples is not None:
ds = Subset(ds, samples)
elif restrict_samples is not None:
ds = Subset(ds, range(restrict_samples))

return ds


@cli.command()
@click.option("--debug", is_flag=True)
@click.option("--data-key", type=str, default="data")
@click.argument("config-file", type=click.Path(exists=True, path_type=Path))
@click.argument("out-file", type=click.Path(path_type=Path))
def make_dataset(data_key: str, config_file: Path, out_file: Path):
@click.argument("config-file", type=click.Path(exists=True))
@click.argument("out-file", type=click.Path())
def make_dataset(debug: bool, data_key: str, config_file: str, out_file: str):
"""
Create a dataset based on the provided configuration and save to a pickle file.

:param debug: Flag for debugging.
:type debug: bool
:param data_key: Key for dataset configuration in the config dictionary, defaults to "data".
:type data_key: str
:param config_file: Path to the configuration file.
@@ -169,7 +183,7 @@ def make_dataset(data_key: str, config_file: Path, out_file: Path):
with open(config_file) as f:
config = yaml.safe_load(f)

ds = make_dataset_(config, data_key=data_key)
ds = make_dataset_(config, debug=debug, data_key=data_key)

with open(out_file, "wb") as f:
pickle.dump(ds, f)
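The hunks above restore the `--debug` flag and the `samples` argument of `make_dataset_`, with the precedence visible in the diff: `debug` caps the dataset at 1,000 samples, an explicit `samples` list comes next, and the optional top-level `restrict_samples` config key is the fallback. A standalone sketch of that precedence; the toy dataset and config are stand-ins, not project code:

```python
from torch.utils.data import Dataset, Subset

class _ToyDataset(Dataset):
    def __len__(self):
        return 5_000
    def __getitem__(self, i):
        return i

def subset_like_make_dataset_(ds, config, debug=False, samples=None):
    restrict_samples = config.get("restrict_samples", None)
    if debug:
        ds = Subset(ds, range(1_000))           # debug wins
    elif samples is not None:
        ds = Subset(ds, samples)                # explicit sample indices
    elif restrict_samples is not None:
        ds = Subset(ds, range(restrict_samples))
    return ds

print(len(subset_like_make_dataset_(_ToyDataset(), {}, debug=True)))             # 1000
print(len(subset_like_make_dataset_(_ToyDataset(), {"restrict_samples": 200})))  # 200
```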
@@ -222,8 +236,6 @@ def compute_burdens_(
.. note::
Checkpoint models all corresponding to the same repeat are averaged for that repeat.
"""
logger.setLevel(logging.INFO)

if not skip_burdens:
logger.info("agg_models[*][*].reverse:")
pprint(
@@ -235,38 +247,10 @@

data_config = config["data"]

if "sample_file" in config:
sample_file = Path(config["sample_file"])
logger.info(f"Using samples from {sample_file}")
if sample_file.suffix == ".pkl":
with open(sample_file, "rb") as f:
sample_ids = np.array(pickle.load(f))
elif sample_file.suffix == ".zarr":
sample_ids = zarr.load(sample_file)
elif sample_file.suffix == ".npy":
sample_ids = np.load(sample_file)
else:
raise ValueError("Unknown file type for sample_file")
ds_samples = ds.get_metadata()["samples"]
sample_indices = np.where(
np.isin(ds_samples.astype(str), sample_ids.astype(str))
)[0]
if debug:
sample_indices = sample_indices[:1000]
elif debug:
sample_indices = np.arange(min(1000, len(ds)))
else:
sample_indices = np.arange(len(ds))

logger.info(
f"Computing gene impairment for {sample_indices.shape[0]} samples: {sample_indices}"
)
ds = Subset(ds, sample_indices)

ds_full = ds.dataset # if isinstance(ds, Subset) else ds
ds_full = ds.dataset if isinstance(ds, Subset) else ds
collate_fn = getattr(ds_full, "collate_fn", None)
n_total_samples = len(ds)
ds_full.rare_embedding.skip_embedding = skip_burdens
ds.rare_embedding.skip_embedding = skip_burdens

if chunk is not None:
if n_chunks is None:
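Regarding the docstring note above that checkpoint models belonging to the same repeat are averaged: a minimal sketch of that idea, assuming each checkpoint model maps a batch to per-gene burden scores. This is an illustration, not the repository's implementation:

```python
import torch

def average_burdens_over_checkpoints(checkpoint_models, batch):
    # Average the predictions of all checkpoints that belong to one repeat.
    with torch.no_grad():
        preds = [model(batch) for model in checkpoint_models]
    return torch.stack(preds, dim=0).mean(dim=0)
```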
@@ -919,7 +903,7 @@ def compute_burdens(
with open(dataset_file, "rb") as f:
dataset = pickle.load(f)
else:
dataset = make_dataset_(data_config)
dataset = make_dataset_(config)

if torch.cuda.is_available():
logger.info("Using GPU")