diff --git a/deeprvat/cv_utils.py b/deeprvat/cv_utils.py index 7c97271e..68a285e5 100644 --- a/deeprvat/cv_utils.py +++ b/deeprvat/cv_utils.py @@ -59,9 +59,11 @@ def spread_config( cv_path = f"{config_template['cv_path']}/{n_folds}_fold" for module in data_modules: config = copy.deepcopy(config_template) - sample_file = f"{cv_path}/samples_{split}{fold}.pkl" - logger.info(f"setting sample file {sample_file}") - config["sample_file"] = sample_file + data_slots = DATA_SLOT_DICT[module] + for data_slot in data_slots: + sample_file = f"{cv_path}/samples_{split}{fold}.pkl" + logger.info(f"setting sample file {sample_file}") + config[data_slot]["dataset_config"]["sample_file"] = sample_file if (module == "deeprvat") | (module == "deeprvat_pretrained"): logger.info("Writing baseline directories") @@ -89,7 +91,8 @@ def generate_test_config(input_config, out_file, fold, n_folds): split = "test" sample_file = f"{cv_path}/samples_{split}{fold}.pkl" logger.info(f"setting sample file {sample_file}") - config["sample_file"] = sample_file + for data_slot in DATA_SLOT_DICT["deeprvat"]: + config[data_slot]["dataset_config"]["sample_file"] = sample_file with open(out_file, "w") as f: yaml.dump(config, f) diff --git a/deeprvat/data/dense_gt.py b/deeprvat/data/dense_gt.py index 7ebdaea9..12247518 100644 --- a/deeprvat/data/dense_gt.py +++ b/deeprvat/data/dense_gt.py @@ -362,7 +362,7 @@ def setup_phenotypes( # account for the fact that genotypes.h5 and phenotype_df can have different # orders of their samples self.index_map_geno, _ = get_matched_sample_indices( - samples_gt.astype(str), self.samples.astype(str) + samples_gt.astype(int), self.samples.astype(int) ) # get_matched_sample_indices is a much, much faster implementation of the code below # self.index_map_geno = [np.where(samples_gt.astype(int) == i) for i in self.samples.astype(int)] @@ -614,20 +614,13 @@ def setup_variants( "Annotation dataframe has inconsistent allele frequency values" ) variants_with_af = safe_merge( - variants[["id"]].reset_index(drop=True), af_annotation, how="left" + variants[["id"]].reset_index(drop=True), af_annotation ) assert np.all( variants_with_af["id"].to_numpy() == variants["id"].to_numpy() ) - af_isna = variants_with_af[af_col].isna() - if af_isna.sum() > 0: - logger.warning( - f"Dropping {af_isna.sum()} variants missing from annotation dataframe" - ) - mask = ( - (~af_isna) - & (variants_with_af[af_col] >= af_threshold) - & (variants_with_af[af_col] <= 1 - af_threshold) + mask = (variants_with_af[af_col] >= af_threshold) & ( + variants_with_af[af_col] <= 1 - af_threshold ) mask = mask.to_numpy() del variants_with_af @@ -938,10 +931,11 @@ def get_metadata(self) -> Dict[str, Any]: result = { "variant_metadata": self.variants[ ["id", "common_variant_mask", "rare_variant_mask", "matrix_index"] - ], - "samples": self.samples, + ] } if self.use_rare_variants: if hasattr(self.rare_embedding, "get_metadata"): - result["rare_embedding_metadata"] = self.rare_embedding.get_metadata() + result.update( + {"rare_embedding_metadata": self.rare_embedding.get_metadata()} + ) return result diff --git a/deeprvat/deeprvat/associate.py b/deeprvat/deeprvat/associate.py index 03404067..5af2e770 100644 --- a/deeprvat/deeprvat/associate.py +++ b/deeprvat/deeprvat/associate.py @@ -7,7 +7,7 @@ import sys from pathlib import Path from pprint import pprint -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple import click import dask.dataframe as dd @@ -115,7 +115,9 @@ def cli(): def make_dataset_( config: Dict, + debug: bool = False, data_key="data", + samples: Optional[List[int]] = None, ) -> Dataset: """ Create a dataset based on the configuration. @@ -147,17 +149,29 @@ def make_dataset_( **copy.deepcopy(data_config["dataset_config"]), ) + restrict_samples = config.get("restrict_samples", None) + if debug: + logger.info("Debug flag set; Using only 1000 samples") + ds = Subset(ds, range(1_000)) + elif samples is not None: + ds = Subset(ds, samples) + elif restrict_samples is not None: + ds = Subset(ds, range(restrict_samples)) + return ds @cli.command() +@click.option("--debug", is_flag=True) @click.option("--data-key", type=str, default="data") -@click.argument("config-file", type=click.Path(exists=True, path_type=Path)) -@click.argument("out-file", type=click.Path(path_type=Path)) -def make_dataset(data_key: str, config_file: Path, out_file: Path): +@click.argument("config-file", type=click.Path(exists=True)) +@click.argument("out-file", type=click.Path()) +def make_dataset(debug: bool, data_key: str, config_file: str, out_file: str): """ Create a dataset based on the provided configuration and save to a pickle file. + :param debug: Flag for debugging. + :type debug: bool :param data_key: Key for dataset configuration in the config dictionary, defaults to "data". :type data_key: str :param config_file: Path to the configuration file. @@ -169,7 +183,7 @@ def make_dataset(data_key: str, config_file: Path, out_file: Path): with open(config_file) as f: config = yaml.safe_load(f) - ds = make_dataset_(config, data_key=data_key) + ds = make_dataset_(config, debug=debug, data_key=data_key) with open(out_file, "wb") as f: pickle.dump(ds, f) @@ -222,8 +236,6 @@ def compute_burdens_( .. note:: Checkpoint models all corresponding to the same repeat are averaged for that repeat. """ - logger.setLevel(logging.INFO) - if not skip_burdens: logger.info("agg_models[*][*].reverse:") pprint( @@ -235,38 +247,10 @@ def compute_burdens_( data_config = config["data"] - if "sample_file" in config: - sample_file = Path(config["sample_file"]) - logger.info(f"Using samples from {sample_file}") - if sample_file.suffix == ".pkl": - with open(sample_file, "rb") as f: - sample_ids = np.array(pickle.load(f)) - elif sample_file.suffix == ".zarr": - sample_ids = zarr.load(sample_file) - elif sample_file.suffix == ".npy": - sample_ids = np.load(sample_file) - else: - raise ValueError("Unknown file type for sample_file") - ds_samples = ds.get_metadata()["samples"] - sample_indices = np.where( - np.isin(ds_samples.astype(str), sample_ids.astype(str)) - )[0] - if debug: - sample_indices = sample_indices[:1000] - elif debug: - sample_indices = np.arange(min(1000, len(ds))) - else: - sample_indices = np.arange(len(ds)) - - logger.info( - f"Computing gene impairment for {sample_indices.shape[0]} samples: {sample_indices}" - ) - ds = Subset(ds, sample_indices) - - ds_full = ds.dataset # if isinstance(ds, Subset) else ds + ds_full = ds.dataset if isinstance(ds, Subset) else ds collate_fn = getattr(ds_full, "collate_fn", None) n_total_samples = len(ds) - ds_full.rare_embedding.skip_embedding = skip_burdens + ds.rare_embedding.skip_embedding = skip_burdens if chunk is not None: if n_chunks is None: @@ -919,7 +903,7 @@ def compute_burdens( with open(dataset_file, "rb") as f: dataset = pickle.load(f) else: - dataset = make_dataset_(data_config) + dataset = make_dataset_(config) if torch.cuda.is_available(): logger.info("Using GPU") diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index ae9995a0..d730e03d 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -82,9 +82,8 @@ def subset_samples( input_tensor: torch.Tensor, covariates: torch.Tensor, y: torch.Tensor, - sample_ids: torch.Tensor, min_variant_count: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # First sum over annotations (dim 2) for each variant in each gene. # Then get the number of non-zero values across all variants in all # genes for each sample. @@ -106,23 +105,21 @@ def subset_samples( input_tensor = input_tensor[mask] covariates = covariates[mask] y = y[mask] - sample_ids = sample_ids[mask] logger.info(f"{input_tensor.shape[0]} / {n_samples_orig} samples kept") - return input_tensor, covariates, y, sample_ids + return input_tensor, covariates, y def make_dataset_( debug: bool, pickle_only: bool, compression_level: int, - training_dataset_file: Optional[Path], + training_dataset_file: Optional[str], config_file: Union[str, Path], - input_tensor_out_file: Path, - covariates_out_file: Path, - y_out_file: Path, - samples_out_file: Path, + input_tensor_out_file: str, + covariates_out_file: str, + y_out_file: str, ): """ Subfunction of make_dataset() @@ -228,11 +225,10 @@ def make_dataset_( ) covariates = torch.cat([b["x_phenotypes"] for b in batches]) y = torch.cat([b["y"] for b in batches]) - sample_ids = np.concatenate([b["sample"] for b in batches]) logger.info("Subsetting samples by min_variant_count and missing y values") - input_tensor, covariates, y, sample_ids = subset_samples( - input_tensor, covariates, y, sample_ids, config["training"]["min_variant_count"] + input_tensor, covariates, y = subset_samples( + input_tensor, covariates, y, config["training"]["min_variant_count"] ) if not pickle_only: @@ -246,7 +242,6 @@ def make_dataset_( del input_tensor zarr.save_array(covariates_out_file, covariates.numpy()) zarr.save_array(y_out_file, y.numpy()) - zarr.save_array(samples_out_file, sample_ids) # DEBUG return ds.dataset @@ -256,22 +251,20 @@ def make_dataset_( @click.option("--debug", is_flag=True) @click.option("--pickle-only", is_flag=True) @click.option("--compression-level", type=int, default=1) -@click.option("--training-dataset-file", type=click.Path(exists=True, path_type=Path)) -@click.argument("config-file", type=click.Path(exists=True, path_type=Path)) -@click.argument("input-tensor-out-file", type=click.Path(path_type=Path)) -@click.argument("covariates-out-file", type=click.Path(path_type=Path)) -@click.argument("y-out-file", type=click.Path(path_type=Path)) -@click.argument("samples-out-file", type=click.Path(path_type=Path)) +@click.option("--training-dataset-file", type=click.Path()) +@click.argument("config-file", type=click.Path(exists=True)) +@click.argument("input-tensor-out-file", type=click.Path()) +@click.argument("covariates-out-file", type=click.Path()) +@click.argument("y-out-file", type=click.Path()) def make_dataset( debug: bool, pickle_only: bool, compression_level: int, - training_dataset_file: Optional[Path], - config_file: Path, - input_tensor_out_file: Path, - covariates_out_file: Path, - y_out_file: Path, - samples_out_file: Path, + training_dataset_file: Optional[str], + config_file: str, + input_tensor_out_file: str, + covariates_out_file: str, + y_out_file: str, ): """ Uses function make_dataset_() to convert dataset to sparse format and stores the respective data @@ -305,7 +298,6 @@ def make_dataset( input_tensor_out_file, covariates_out_file, y_out_file, - samples_out_file, ) @@ -369,10 +361,9 @@ def __init__( logger.info("Keeping all input tensors in main memory") for pheno, pheno_data in self.data.items(): - n_samples = pheno_data["sample_indices"].shape[0] - if pheno_data["y"].shape == (n_samples, 1): + if pheno_data["y"].shape == (pheno_data["input_tensor_zarr"].shape[0], 1): pheno_data["y"] = pheno_data["y"].squeeze() - elif pheno_data["y"].shape != (n_samples,): + elif pheno_data["y"].shape != (pheno_data["input_tensor_zarr"].shape[0],): raise NotImplementedError( "Multi-phenotype training is only implemented via multiple y files" ) @@ -593,10 +584,8 @@ def __init__( self.n_annotations = any_pheno_data["input_tensor_zarr"].shape[2] self.n_covariates = any_pheno_data["covariates"].shape[1] - for pheno, pheno_data in self.data.items(): - # n_samples = pheno_data["input_tensor_zarr"].shape[0] - sample_indices = pheno_data["sample_indices"] - n_samples = sample_indices.shape[0] + for _, pheno_data in self.data.items(): + n_samples = pheno_data["input_tensor_zarr"].shape[0] assert pheno_data["covariates"].shape[0] == n_samples assert pheno_data["y"].shape[0] == n_samples @@ -611,33 +600,29 @@ def __init__( samples = self.upsample() n_samples = self.samples.shape[0] logger.info(f"New sample number: {n_samples}") - # else: - # samples = np.arange(n_samples) + else: + samples = np.arange(n_samples) # Sample self.n_samples * train_proportion samples with replacement # for training, use all remaining samples for validation if train_proportion == 1.0: - self.train_samples = sample_indices - self.val_samples = sample_indices + self.train_samples = self.samples + self.val_samples = self.samples else: n_train_samples = round(n_samples * train_proportion) rng = np.random.default_rng() # select training samples from the underlying dataframe train_samples = np.sort( rng.choice( - sample_indices, - size=n_train_samples, - replace=sample_with_replacement, + samples, size=n_train_samples, replace=sample_with_replacement ) ) # samples which are not part of train_samples, but in samples # are validation samples. - val_samples = np.setdiff1d(sample_indices, train_samples) - logger.info( - f"{pheno}: Using {train_samples.shape[0]} samples for training, " - f"{val_samples.shape[0]} for validation" - ) - pheno_data["samples"] = {"train": train_samples, "val": val_samples} + pheno_data["samples"] = { + "train": train_samples, + "val": np.setdiff1d(samples, train_samples), + } self.save_hyperparameters( # "min_variant_count", @@ -655,8 +640,6 @@ def upsample(self) -> np.ndarray: does not work at the moment for multi-phenotype training. Needs some minor changes to make it work again """ - raise NotImplementedError - unique_values = self.y.unique() if unique_values.size() != torch.Size([2]): raise ValueError( @@ -945,19 +928,18 @@ def run_bagging( @cli.command() @click.option("--debug", is_flag=True) -@click.option("--training-gene-file", type=click.Path(exists=True, path_type=Path)) +@click.option("--training-gene-file", type=click.Path(exists=True)) @click.option("--n-trials", type=int, default=1) @click.option("--trial-id", type=int) -@click.option("--samples-to-keep", type=click.Path(exists=True, path_type=Path)) +@click.option("--sample-file", type=click.Path(exists=True)) @click.option( "--phenotype", multiple=True, type=( str, - click.Path(exists=True, path_type=Path), - click.Path(exists=True, path_type=Path), - click.Path(exists=True, path_type=Path), - click.Path(exists=True, path_type=Path), + click.Path(exists=True), + click.Path(exists=True), + click.Path(exists=True), ), ) @click.argument("config-file", type=click.Path(exists=True)) @@ -965,11 +947,11 @@ def run_bagging( @click.argument("hpopt-file", type=click.Path()) def train( debug: bool, - training_gene_file: Optional[Path], + training_gene_file: Optional[str], n_trials: int, trial_id: Optional[int], - samples_to_keep: Optional[Path], - phenotype: Tuple[Tuple[str, Path, Path, Path, Path]], + sample_file: Optional[str], + phenotype: Tuple[Tuple[str, str, str, str]], config_file: str, log_dir: str, hpopt_file: str, @@ -985,8 +967,8 @@ def train( :type n_trials: int :param trial_id: Current trial in range n_trials. (optional) :type trial_id: Optional[int] - :param samples_to_keep: Path to a pickle file specifying which samples should be considered during training. (optional) - :type samples_to_keep: Optional[str] + :param sample_file: Path to a pickle file specifying which samples should be considered during training. (optional) + :type sample_file: Optional[str] :param phenotype: Array of phenotypes, containing an array of paths where the underlying data is stored: - str: Phenotype name - str: Annotated gene variants as zarr file @@ -1020,57 +1002,30 @@ def train( logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory}") logger.info("Loading input data") - - if samples_to_keep is not None or "sample_file" in config: - if samples_to_keep is not None: - keep_sample_file = samples_to_keep - if "sample_file" in config: - logger.warning( - f"--samples-to-keep option overrides sample_file in config.yaml" - ) - else: - keep_sample_file = Path(config["sample_file"]) - - logger.info(f"Using samples from {keep_sample_file}") - if keep_sample_file.suffix == ".pkl": - with open(keep_sample_file, "rb") as f: - sample_ids = np.array(pickle.load(f)) - elif keep_sample_file.suffix == ".zarr": - sample_ids = zarr.load(keep_sample_file) - elif keep_sample_file.suffix == ".npy": - sample_ids = np.load(keep_sample_file) - else: - raise ValueError("Unknown file type for sample_file") + if sample_file is not None: + logger.info(f"Using training samples from {sample_file}") + with open(sample_file, "rb") as f: + samples = pickle.load(f)["training_samples"] + if debug: + samples = [s for s in samples if s < 1000] else: - sample_ids = None + samples = slice(None) data = dict() # pack underlying data into a single dict that can be passed to downstream functions - for pheno, input_tensor_file, covariates_file, y_file, sample_file in phenotype: + for pheno, input_tensor_file, covariates_file, y_file in phenotype: data[pheno] = dict() data[pheno]["input_tensor_zarr"] = zarr.open( input_tensor_file, mode="r" ) # TODO: subset here? - n_samples = data[pheno]["input_tensor_zarr"].shape[0] - - data[pheno]["sample_ids"] = zarr.load(sample_file) - sample_indices = np.arange(n_samples) - if sample_ids is not None: - sample_indices = sample_indices[ - np.isin(data[pheno]["sample_ids"], sample_ids) - ] - - if debug: - sample_indices = sample_indices[:1000] - - data[pheno]["sample_indices"] = sample_indices - data[pheno]["covariates"] = torch.tensor( - zarr.load(covariates_file)[sample_indices] - ) # TODO: or maybe shouldn't subset here? - data[pheno]["y"] = torch.tensor( - zarr.load(y_file)[sample_indices] - ) # TODO: or maybe shouldn't subset here? + zarr.open(covariates_file, mode="r")[:] + )[ + samples + ] # TODO: or maybe shouldn't subset here? + data[pheno]["y"] = torch.tensor(zarr.open(y_file, mode="r")[:])[ + samples + ] # TODO: or maybe shouldn't subset here? if training_gene_file is not None: with open(training_gene_file, "rb") as f: diff --git a/deeprvat/utils.py b/deeprvat/utils.py index 2cbbc1ea..a9c18801 100644 --- a/deeprvat/utils.py +++ b/deeprvat/utils.py @@ -225,7 +225,6 @@ def safe_merge( right: pd.DataFrame, validate: str = "1:1", equal_row_nums: bool = False, - **kwargs, ): """ Safely merge two pandas DataFrames. @@ -252,7 +251,7 @@ def safe_merge( "left and right dataframes are unequal" ) - merged = pd.merge(left, right, validate=validate, **kwargs) + merged = pd.merge(left, right, validate=validate) try: assert len(merged) == len(left) diff --git a/example/config.yaml b/example/config.yaml index 435a4b22..c6a6a3fb 100644 --- a/example/config.yaml +++ b/example/config.yaml @@ -22,9 +22,6 @@ baseline_results: base: baseline_results type: missense/skat -cv_path: sample_files -n_folds: 5 - alpha: 0.05 n_burden_chunks: 2 diff --git a/pipelines/association_testing/association_dataset.snakefile b/pipelines/association_testing/association_dataset.snakefile index 818e43de..9c8ba228 100644 --- a/pipelines/association_testing/association_dataset.snakefile +++ b/pipelines/association_testing/association_dataset.snakefile @@ -1,5 +1,8 @@ configfile: "config.yaml" +debug_flag = config.get('debug', False) +debug = '--debug ' if debug_flag else '' + rule association_dataset: input: config = '{phenotype}/deeprvat/hpopt_config.yaml' @@ -11,5 +14,6 @@ rule association_dataset: priority: 30 shell: 'deeprvat_associate make-dataset ' + + debug + '{input.config} ' '{output}' diff --git a/pipelines/association_testing_pretrained.snakefile b/pipelines/association_testing_pretrained.snakefile index a26b7f1b..9d7799df 100644 --- a/pipelines/association_testing_pretrained.snakefile +++ b/pipelines/association_testing_pretrained.snakefile @@ -9,7 +9,7 @@ training_phenotypes = config["training"].get("phenotypes", phenotypes) n_burden_chunks = config.get('n_burden_chunks', 1) if not debug_flag else 2 n_regression_chunks = config.get('n_regression_chunks', 40) if not debug_flag else 2 -n_avg_chunks = config.get('n_avg_chunks', 40) +n_avg_chunks = config.get('n_avg_chunks', 1) n_trials = config['hyperparameter_optimization']['n_trials'] n_bags = config['training']['n_bags'] if not debug_flag else 3 n_repeats = config['n_repeats'] diff --git a/pipelines/cv_training/cv_burdens.snakefile b/pipelines/cv_training/cv_burdens.snakefile index 0a6960bc..f618e136 100644 --- a/pipelines/cv_training/cv_burdens.snakefile +++ b/pipelines/cv_training/cv_burdens.snakefile @@ -42,6 +42,10 @@ rule make_deeprvat_test_config: # pass the sample file here # then just use this data set nomrally for burden computation use rule association_dataset from deeprvat_associate as deeprvat_association_dataset with: + input: + config="cv_split{cv_split}/deeprvat/{phenotype}/deeprvat/hpopt_config_test.yaml", + output: + "cv_split{cv_split}/deeprvat/{phenotype}/deeprvat/association_dataset.pkl", threads: 4 @@ -89,31 +93,14 @@ rule combine_test_burdens: ) -use rule link_burdens from deeprvat_associate as deeprvat_link_burdens with: - input: - checkpoints = expand( - 'cv_split{cv_split}/deeprvat' / model_path / "repeat_{repeat}/best/bag_{bag}.ckpt", - cv_split=range(cv_splits), repeat=range(n_repeats), bag=range(n_bags) - ), - dataset = 'cv_split0/deeprvat/{phenotype}/deeprvat/association_dataset.pkl', - data_config = 'cv_split{cv_split}/deeprvat/{phenotype}/deeprvat/hpopt_config_test.yaml', - model_config = "cv_split{cv_split}/deeprvat" / model_path / 'config.yaml', +use rule link_burdens from deeprvat_workflow as deeprvat_link_burdens with: params: prefix="cv_split{cv_split}/deeprvat", -use rule compute_burdens from deeprvat_associate as deeprvat_compute_burdens with: - input: - reversed = "cv_split{cv_split}/deeprvat" / model_path / "reverse_finished.tmp", - checkpoints = expand( - 'cv_split{cv_split}/deeprvat' / model_path / "repeat_{repeat}/best/bag_{bag}.ckpt", - cv_split=range(cv_splits), repeat=range(n_repeats), bag=range(n_bags) - ), - dataset = 'cv_split0/deeprvat/{phenotype}/deeprvat/association_dataset.pkl', - data_config = 'cv_split{cv_split}/deeprvat/{phenotype}/deeprvat/hpopt_config_test.yaml', - model_config = "cv_split{cv_split}/deeprvat" / model_path / 'config.yaml', +use rule compute_burdens from deeprvat_workflow as deeprvat_compute_burdens with: params: prefix="cv_split{cv_split}/deeprvat", -use rule reverse_models from deeprvat_associate as deeprvat_reverse_models +use rule reverse_models from deeprvat_workflow as deeprvat_reverse_models diff --git a/pipelines/cv_training/cv_training.snakefile b/pipelines/cv_training/cv_training.snakefile index 0cd08647..2ad880ae 100644 --- a/pipelines/cv_training/cv_training.snakefile +++ b/pipelines/cv_training/cv_training.snakefile @@ -52,36 +52,18 @@ use rule best_training_run from deeprvat_workflow as deeprvat_best_training_run use rule train from deeprvat_workflow as deeprvat_train with: priority: 1000 - input: - config = expand('cv_split{{cv_split}}/deeprvat/{phenotype}/deeprvat/hpopt_config.yaml', - phenotype=training_phenotypes), - input_tensor = expand('cv_split0/deeprvat/{phenotype}/deeprvat/input_tensor.zarr', - phenotype=training_phenotypes), - covariates = expand('cv_split0/deeprvat/{phenotype}/deeprvat/covariates.zarr', - phenotype=training_phenotypes), - y = expand('cv_split0/deeprvat/{phenotype}/deeprvat/y.zarr', - phenotype=training_phenotypes), - samples = expand('cv_split0/deeprvat/{phenotype}/deeprvat/sample_ids.zarr', - phenotype=training_phenotypes), params: prefix = 'cv_split{cv_split}/deeprvat', phenotypes = " ".join( #TODO like need the prefix here as well [f"--phenotype {p} " - f"cv_split0/deeprvat/{p}/deeprvat/input_tensor.zarr " - f"cv_split0/deeprvat/{p}/deeprvat/covariates.zarr " - f"cv_split0/deeprvat/{p}/deeprvat/y.zarr " - f"cv_split0/deeprvat/{p}/deeprvat/sample_ids.zarr " + f"cv_split{{cv_split}}/deeprvat/{p}/deeprvat/input_tensor.zarr " + f"cv_split{{cv_split}}/deeprvat/{p}/deeprvat/covariates.zarr " + f"cv_split{{cv_split}}/deeprvat/{p}/deeprvat/y.zarr" for p in training_phenotypes]) use rule training_dataset from deeprvat_workflow as deeprvat_training_dataset - # output: - # input_tensor=directory("{phenotype}/deeprvat/input_tensor.zarr"), - # covariates=directory("{phenotype}/deeprvat/covariates.zarr"), - # y=directory("{phenotype}/deeprvat/y.zarr"), - # sample_ids=directory("{phenotype}/deeprvat/sample_ids.zarr"), - -# use rule training_dataset_pickle from deeprvat_workflow as deeprvat_training_dataset_pickle +use rule training_dataset_pickle from deeprvat_workflow as deeprvat_training_dataset_pickle use rule config from deeprvat_workflow as deeprvat_config with: input: @@ -97,4 +79,12 @@ use rule config from deeprvat_workflow as deeprvat_config with: for b in input.baseline ]) if wildcards.phenotype in training_phenotypes else ' ', baseline_out = lambda wildcards: f'--baseline-results-out cv_split{wildcards.cv_split}/deeprvat/{wildcards.phenotype}/deeprvat/baseline_results.parquet' if wildcards.phenotype in training_phenotypes else ' ', - seed_genes_out = lambda wildcards: f'--seed-genes-out cv_split{wildcards.cv_split}/deeprvat/{wildcards.phenotype}/deeprvat/seed_genes.parquet' if wildcards.phenotype in training_phenotypes else ' ' + seed_genes_out = lambda wildcards: f'--seed-genes-out cv_split{wildcards.cv_split}/deeprvat/{wildcards.phenotype}/deeprvat/seed_genes.parquet' if wildcards.phenotype in training_phenotypes else ' ', + association_only = lambda wildcards: f'--association-only' if wildcards.phenotype not in training_phenotypes else ' ' + + + + + + + diff --git a/pipelines/cv_training/cv_training_association_testing.snakefile b/pipelines/cv_training/cv_training_association_testing.snakefile index 14443e83..bb825587 100644 --- a/pipelines/cv_training/cv_training_association_testing.snakefile +++ b/pipelines/cv_training/cv_training_association_testing.snakefile @@ -14,7 +14,7 @@ burden_phenotype = phenotypes[0] n_burden_chunks = config.get("n_burden_chunks", 1) if not debug_flag else 2 n_regression_chunks = config.get("n_regression_chunks", 40) if not debug_flag else 2 -n_avg_chunks = config.get('n_avg_chunks', 40) +n_avg_chunks = config.get('n_avg_chunks', 1) n_trials = config["hyperparameter_optimization"]["n_trials"] n_bags = config["training"]["n_bags"] if not debug_flag else 3 n_repeats = config["n_repeats"] @@ -28,21 +28,26 @@ n_parallel_training_jobs = config["training"].get("n_parallel_jobs", 1) wildcard_constraints: repeat="\d+", trial="\d+", - cv_split="\d+", - phenotype="[A-z0-9_]+", cv_splits = config.get("n_folds", 5) cv_exp = True - - include: "cv_training.snakefile" include: "cv_burdens.snakefile" include: "../association_testing/burdens.snakefile" include: "../association_testing/regress_eval.snakefile" +rule all_evaluate: #regress_eval.snakefile + input: + significant=expand( + "{phenotype}/deeprvat/eval/significant.parquet", phenotype=phenotypes + ), + results=expand( + "{phenotype}/deeprvat/eval/all_results.parquet", phenotype=phenotypes + ), + rule all_regression: #regress_eval.snakefile input: @@ -80,25 +85,6 @@ rule all_training: #cv_training.snakefile cv_split=range(cv_splits), ), -rule all_training_dataset: #cv_training.snakefile - input: - expand( - "cv_split0/deeprvat/{phenotype}/deeprvat/input_tensor.zarr", - phenotype=training_phenotypes, - ), - expand( - "cv_split0/deeprvat/{phenotype}/deeprvat/covariates.zarr", - phenotype=training_phenotypes, - ), - expand( - "cv_split0/deeprvat/{phenotype}/deeprvat/y.zarr", - phenotype=training_phenotypes, - ), - expand( - "cv_split0/deeprvat/{phenotype}/deeprvat/sample_ids.zarr", - phenotype=training_phenotypes, - ), - rule all_config: #cv_training.snakefile input: diff --git a/pipelines/training/config.snakefile b/pipelines/training/config.snakefile index 93075fff..3799ef7e 100644 --- a/pipelines/training/config.snakefile +++ b/pipelines/training/config.snakefile @@ -33,11 +33,15 @@ rule config: baseline_out=lambda wildcards: f"--baseline-results-out {wildcards.phenotype}/deeprvat/baseline_results.parquet" if wildcards.phenotype in training_phenotypes else " ", + association_only=lambda wildcards: f"--association-only" + if wildcards.phenotype not in training_phenotypes + else " ", shell: ( "deeprvat_config update-config " "--phenotype {wildcards.phenotype} " - "{params.baseline_results}" + "{params.association_only} " + "{params.baseline_results} " "{params.baseline_out} " "{params.seed_genes_out} " "{input.config} " diff --git a/pipelines/training/train.snakefile b/pipelines/training/train.snakefile index 4d8e9cd7..f33fd5b6 100644 --- a/pipelines/training/train.snakefile +++ b/pipelines/training/train.snakefile @@ -51,8 +51,7 @@ rule train: [f"--phenotype {p} " f"{p}/deeprvat/input_tensor.zarr " f"{p}/deeprvat/covariates.zarr " - f"{p}/deeprvat/y.zarr " - f"{p}/deeprvat/sample_ids.zarr " + f"{p}/deeprvat/y.zarr" for p in training_phenotypes]), prefix = '.', priority: 1000 @@ -65,8 +64,7 @@ rule train: + debug + '--trial-id {{2}} ' "{params.phenotypes} " - # 'config.yaml ' - "{params.prefix}/config.yaml " + 'config.yaml ' '{params.prefix}/{model_path}/repeat_{{1}}/trial{{2}} ' "{params.prefix}/{model_path}/repeat_{{1}}/hyperparameter_optimization.db '&&' " "touch {params.prefix}/{model_path}/repeat_{{1}}/trial{{2}}/finished.tmp " diff --git a/pipelines/training/training_dataset.snakefile b/pipelines/training/training_dataset.snakefile index 3c40760b..2cf00229 100644 --- a/pipelines/training/training_dataset.snakefile +++ b/pipelines/training/training_dataset.snakefile @@ -1,16 +1,15 @@ rule training_dataset: input: config="{phenotype}/deeprvat/hpopt_config.yaml", - # training_dataset="{phenotype}/deeprvat/training_dataset.pkl", + training_dataset="{phenotype}/deeprvat/training_dataset.pkl", output: input_tensor=directory("{phenotype}/deeprvat/input_tensor.zarr"), covariates=directory("{phenotype}/deeprvat/covariates.zarr"), y=directory("{phenotype}/deeprvat/y.zarr"), - sample_ids=directory("{phenotype}/deeprvat/sample_ids.zarr"), threads: 8 resources: mem_mb=lambda wildcards, attempt: 32000 + 12000 * attempt, - load=32000, + load=16000, priority: 5000 shell: ( @@ -19,12 +18,11 @@ rule training_dataset: + "--compression-level " + str(tensor_compression_level) + " " - # "--training-dataset-file {input.training_dataset} " + "--training-dataset-file {input.training_dataset} " "{input.config} " "{output.input_tensor} " "{output.covariates} " - "{output.y} " - "{output.sample_ids}" + "{output.y}" ) @@ -32,11 +30,11 @@ rule training_dataset_pickle: input: "{phenotype}/deeprvat/hpopt_config.yaml", output: - temp("{phenotype}/deeprvat/training_dataset.pkl"), + "{phenotype}/deeprvat/training_dataset.pkl", threads: 1 resources: mem_mb=40000, # lambda wildcards, attempt: 38000 + 12000 * attempt - load=40000, + load=16000, shell: ( "deeprvat_train make-dataset " @@ -44,4 +42,4 @@ rule training_dataset_pickle: "--training-dataset-file {output} " "{input} " "dummy dummy dummy" - ) + ) \ No newline at end of file diff --git a/pipelines/training_association_testing.snakefile b/pipelines/training_association_testing.snakefile index 8c360136..b816752e 100644 --- a/pipelines/training_association_testing.snakefile +++ b/pipelines/training_association_testing.snakefile @@ -9,7 +9,7 @@ training_phenotypes = config["training"].get("phenotypes", phenotypes) n_burden_chunks = config.get('n_burden_chunks', 1) if not debug_flag else 2 n_regression_chunks = config.get('n_regression_chunks', 40) if not debug_flag else 2 -n_avg_chunks = config.get('n_avg_chunks', 40) +n_avg_chunks = config.get('n_avg_chunks', 1) n_trials = config['hyperparameter_optimization']['n_trials'] n_bags = config['training']['n_bags'] if not debug_flag else 3 n_repeats = config['n_repeats'] diff --git a/tests/deeprvat/test_train.py b/tests/deeprvat/test_train.py index 60d6605e..7fbcbbfa 100644 --- a/tests/deeprvat/test_train.py +++ b/tests/deeprvat/test_train.py @@ -60,7 +60,6 @@ def make_multipheno_data(): ) data[p]["input_tensor"] = data[p]["input_tensor_zarr"][:] data[p]["samples"] = {"train": np.arange(data[p]["y"].shape[0])} - data[p]["sample_indices"] = np.arange(data[p]["y"].shape[0]) return data @@ -146,10 +145,9 @@ def test_make_dataset(phenotype: str, min_variant_count: int, tmp_path: Path): yaml.dump(config, f) # This is the function we want to test - input_tensor_out_file = tmp_path / "input_tensor.zarr" - covariates_out_file = tmp_path / "covariates.zarr" - y_out_file = tmp_path / "y.zarr" - samples_out_file = tmp_path / "sample_ids.zarr" + input_tensor_out_file = str(tmp_path / "input_tensor.zarr") + covariates_out_file = str(tmp_path / "covariates.zarr") + y_out_file = str(tmp_path / "y.zarr") logger.info("Constructing test dataset") test_ds = make_dataset_( False, @@ -160,7 +158,6 @@ def test_make_dataset(phenotype: str, min_variant_count: int, tmp_path: Path): input_tensor_out_file, covariates_out_file, y_out_file, - samples_out_file, ) # Load the data it output