diff --git a/deeprvat/cv_utils.py b/deeprvat/cv_utils.py
index 68a285e5..7c97271e 100644
--- a/deeprvat/cv_utils.py
+++ b/deeprvat/cv_utils.py
@@ -59,11 +59,9 @@ def spread_config(
     cv_path = f"{config_template['cv_path']}/{n_folds}_fold"
     for module in data_modules:
         config = copy.deepcopy(config_template)
-        data_slots = DATA_SLOT_DICT[module]
-        for data_slot in data_slots:
-            sample_file = f"{cv_path}/samples_{split}{fold}.pkl"
-            logger.info(f"setting sample file {sample_file}")
-            config[data_slot]["dataset_config"]["sample_file"] = sample_file
+        sample_file = f"{cv_path}/samples_{split}{fold}.pkl"
+        logger.info(f"setting sample file {sample_file}")
+        config["sample_file"] = sample_file
 
         if (module == "deeprvat") | (module == "deeprvat_pretrained"):
             logger.info("Writing baseline directories")
@@ -91,8 +89,7 @@ def generate_test_config(input_config, out_file, fold, n_folds):
     split = "test"
     sample_file = f"{cv_path}/samples_{split}{fold}.pkl"
     logger.info(f"setting sample file {sample_file}")
-    for data_slot in DATA_SLOT_DICT["deeprvat"]:
-        config[data_slot]["dataset_config"]["sample_file"] = sample_file
+    config["sample_file"] = sample_file
 
     with open(out_file, "w") as f:
         yaml.dump(config, f)
diff --git a/deeprvat/data/dense_gt.py b/deeprvat/data/dense_gt.py
index 12247518..7ebdaea9 100644
--- a/deeprvat/data/dense_gt.py
+++ b/deeprvat/data/dense_gt.py
@@ -362,7 +362,7 @@ def setup_phenotypes(
         # account for the fact that genotypes.h5 and phenotype_df can have different
         # orders of their samples
         self.index_map_geno, _ = get_matched_sample_indices(
-            samples_gt.astype(int), self.samples.astype(int)
+            samples_gt.astype(str), self.samples.astype(str)
         )
         # get_matched_sample_indices is a much, much faster implementation of the code below
         # self.index_map_geno = [np.where(samples_gt.astype(int) == i) for i in self.samples.astype(int)]
@@ -614,13 +614,20 @@ def setup_variants(
                     "Annotation dataframe has inconsistent allele frequency values"
                 )
             variants_with_af = safe_merge(
-                variants[["id"]].reset_index(drop=True), af_annotation
+                variants[["id"]].reset_index(drop=True), af_annotation, how="left"
             )
             assert np.all(
                 variants_with_af["id"].to_numpy() == variants["id"].to_numpy()
             )
-            mask = (variants_with_af[af_col] >= af_threshold) & (
-                variants_with_af[af_col] <= 1 - af_threshold
+            af_isna = variants_with_af[af_col].isna()
+            if af_isna.sum() > 0:
+                logger.warning(
+                    f"Dropping {af_isna.sum()} variants missing from annotation dataframe"
+                )
+            mask = (
+                (~af_isna)
+                & (variants_with_af[af_col] >= af_threshold)
+                & (variants_with_af[af_col] <= 1 - af_threshold)
             )
             mask = mask.to_numpy()
             del variants_with_af
@@ -931,11 +938,10 @@ def get_metadata(self) -> Dict[str, Any]:
         result = {
             "variant_metadata": self.variants[
                 ["id", "common_variant_mask", "rare_variant_mask", "matrix_index"]
-            ]
+            ],
+            "samples": self.samples,
         }
         if self.use_rare_variants:
             if hasattr(self.rare_embedding, "get_metadata"):
-                result.update(
-                    {"rare_embedding_metadata": self.rare_embedding.get_metadata()}
-                )
+                result["rare_embedding_metadata"] = self.rare_embedding.get_metadata()
         return result
diff --git a/deeprvat/deeprvat/associate.py b/deeprvat/deeprvat/associate.py
index 5af2e770..03404067 100644
--- a/deeprvat/deeprvat/associate.py
+++ b/deeprvat/deeprvat/associate.py
@@ -7,7 +7,7 @@
 import sys
 from pathlib import Path
 from pprint import pprint
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import click
 import dask.dataframe as dd
@@ -115,9 +115,7 @@ def cli():
 
 def make_dataset_(
     config: Dict,
-    debug: bool = False,
     data_key="data",
-    samples: Optional[List[int]] = None,
 ) -> Dataset:
     """
     Create a dataset based on the configuration.
@@ -149,29 +147,17 @@ def make_dataset_(
         **copy.deepcopy(data_config["dataset_config"]),
     )
 
-    restrict_samples = config.get("restrict_samples", None)
-    if debug:
-        logger.info("Debug flag set; Using only 1000 samples")
-        ds = Subset(ds, range(1_000))
-    elif samples is not None:
-        ds = Subset(ds, samples)
-    elif restrict_samples is not None:
-        ds = Subset(ds, range(restrict_samples))
-
     return ds
 
 
 @cli.command()
-@click.option("--debug", is_flag=True)
 @click.option("--data-key", type=str, default="data")
-@click.argument("config-file", type=click.Path(exists=True))
-@click.argument("out-file", type=click.Path())
-def make_dataset(debug: bool, data_key: str, config_file: str, out_file: str):
+@click.argument("config-file", type=click.Path(exists=True, path_type=Path))
+@click.argument("out-file", type=click.Path(path_type=Path))
+def make_dataset(data_key: str, config_file: Path, out_file: Path):
     """
     Create a dataset based on the provided configuration and save to a pickle file.
 
-    :param debug: Flag for debugging.
-    :type debug: bool
     :param data_key: Key for dataset configuration in the config dictionary, defaults to "data".
     :type data_key: str
     :param config_file: Path to the configuration file.
@@ -183,7 +169,7 @@ def make_dataset(debug: bool, data_key: str, config_file: str, out_file: str):
     with open(config_file) as f:
         config = yaml.safe_load(f)
 
-    ds = make_dataset_(config, debug=debug, data_key=data_key)
+    ds = make_dataset_(config, data_key=data_key)
 
     with open(out_file, "wb") as f:
         pickle.dump(ds, f)
@@ -236,6 +222,8 @@ def compute_burdens_(
     .. note::
         Checkpoint models all corresponding to the same repeat are averaged for that repeat.
     """
+    logger.setLevel(logging.INFO)
+
     if not skip_burdens:
         logger.info("agg_models[*][*].reverse:")
         pprint(
@@ -247,10 +235,38 @@ def compute_burdens_(
 
     data_config = config["data"]
 
-    ds_full = ds.dataset if isinstance(ds, Subset) else ds
+    if "sample_file" in config:
+        sample_file = Path(config["sample_file"])
+        logger.info(f"Using samples from {sample_file}")
+        if sample_file.suffix == ".pkl":
+            with open(sample_file, "rb") as f:
+                sample_ids = np.array(pickle.load(f))
+        elif sample_file.suffix == ".zarr":
+            sample_ids = zarr.load(sample_file)
+        elif sample_file.suffix == ".npy":
+            sample_ids = np.load(sample_file)
+        else:
+            raise ValueError("Unknown file type for sample_file")
+        ds_samples = ds.get_metadata()["samples"]
+        sample_indices = np.where(
+            np.isin(ds_samples.astype(str), sample_ids.astype(str))
+        )[0]
+        if debug:
+            sample_indices = sample_indices[:1000]
+    elif debug:
+        sample_indices = np.arange(min(1000, len(ds)))
+    else:
+        sample_indices = np.arange(len(ds))
+
+    logger.info(
+        f"Computing gene impairment for {sample_indices.shape[0]} samples: {sample_indices}"
+    )
+    ds = Subset(ds, sample_indices)
+
+    ds_full = ds.dataset  # if isinstance(ds, Subset) else ds
     collate_fn = getattr(ds_full, "collate_fn", None)
     n_total_samples = len(ds)
-    ds.rare_embedding.skip_embedding = skip_burdens
+    ds_full.rare_embedding.skip_embedding = skip_burdens
 
     if chunk is not None:
         if n_chunks is None:
@@ -903,7 +919,7 @@ def compute_burdens(
         with open(dataset_file, "rb") as f:
             dataset = pickle.load(f)
     else:
-        dataset = make_dataset_(config)
+        dataset = make_dataset_(data_config)
 
     if torch.cuda.is_available():
         logger.info("Using GPU")
diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py
index d730e03d..ae9995a0 100644
--- a/deeprvat/deeprvat/train.py
+++ b/deeprvat/deeprvat/train.py
@@ -82,8 +82,9 @@ def subset_samples(
     input_tensor: torch.Tensor,
     covariates: torch.Tensor,
     y: torch.Tensor,
+    sample_ids: torch.Tensor,
     min_variant_count: int,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     # First sum over annotations (dim 2) for each variant in each gene.
     # Then get the number of non-zero values across all variants in all
     # genes for each sample.
@@ -105,21 +106,23 @@ def subset_samples(
     input_tensor = input_tensor[mask]
     covariates = covariates[mask]
     y = y[mask]
+    sample_ids = sample_ids[mask]
 
     logger.info(f"{input_tensor.shape[0]} / {n_samples_orig} samples kept")
 
-    return input_tensor, covariates, y
+    return input_tensor, covariates, y, sample_ids
 
 
 def make_dataset_(
     debug: bool,
     pickle_only: bool,
     compression_level: int,
-    training_dataset_file: Optional[str],
+    training_dataset_file: Optional[Path],
     config_file: Union[str, Path],
-    input_tensor_out_file: str,
-    covariates_out_file: str,
-    y_out_file: str,
+    input_tensor_out_file: Path,
+    covariates_out_file: Path,
+    y_out_file: Path,
+    samples_out_file: Path,
 ):
     """
     Subfunction of make_dataset()
@@ -225,10 +228,11 @@ def make_dataset_(
     )
     covariates = torch.cat([b["x_phenotypes"] for b in batches])
     y = torch.cat([b["y"] for b in batches])
+    sample_ids = np.concatenate([b["sample"] for b in batches])
 
     logger.info("Subsetting samples by min_variant_count and missing y values")
-    input_tensor, covariates, y = subset_samples(
-        input_tensor, covariates, y, config["training"]["min_variant_count"]
+    input_tensor, covariates, y, sample_ids = subset_samples(
+        input_tensor, covariates, y, sample_ids, config["training"]["min_variant_count"]
     )
 
     if not pickle_only:
@@ -242,6 +246,7 @@ def make_dataset_(
         del input_tensor
     zarr.save_array(covariates_out_file, covariates.numpy())
     zarr.save_array(y_out_file, y.numpy())
+    zarr.save_array(samples_out_file, sample_ids)  # DEBUG
 
     return ds.dataset
 
@@ -251,20 +256,22 @@ def make_dataset_(
 @click.option("--debug", is_flag=True)
 @click.option("--pickle-only", is_flag=True)
 @click.option("--compression-level", type=int, default=1)
-@click.option("--training-dataset-file", type=click.Path())
-@click.argument("config-file", type=click.Path(exists=True))
-@click.argument("input-tensor-out-file", type=click.Path())
-@click.argument("covariates-out-file", type=click.Path())
-@click.argument("y-out-file", type=click.Path())
+@click.option("--training-dataset-file", type=click.Path(exists=True, path_type=Path))
+@click.argument("config-file", type=click.Path(exists=True, path_type=Path))
+@click.argument("input-tensor-out-file", type=click.Path(path_type=Path))
+@click.argument("covariates-out-file", type=click.Path(path_type=Path))
+@click.argument("y-out-file", type=click.Path(path_type=Path))
+@click.argument("samples-out-file", type=click.Path(path_type=Path))
 def make_dataset(
     debug: bool,
     pickle_only: bool,
     compression_level: int,
-    training_dataset_file: Optional[str],
-    config_file: str,
-    input_tensor_out_file: str,
-    covariates_out_file: str,
-    y_out_file: str,
+    training_dataset_file: Optional[Path],
+    config_file: Path,
+    input_tensor_out_file: Path,
+    covariates_out_file: Path,
+    y_out_file: Path,
+    samples_out_file: Path,
 ):
     """
     Uses function make_dataset_() to convert dataset to sparse format and stores the respective data
@@ -298,6 +305,7 @@ def make_dataset(
         input_tensor_out_file,
         covariates_out_file,
         y_out_file,
+        samples_out_file,
     )
 
 
@@ -361,9 +369,10 @@ def __init__(
             logger.info("Keeping all input tensors in main memory")
 
         for pheno, pheno_data in self.data.items():
-            if pheno_data["y"].shape == (pheno_data["input_tensor_zarr"].shape[0], 1):
+            n_samples = pheno_data["sample_indices"].shape[0]
+            if pheno_data["y"].shape == (n_samples, 1):
                 pheno_data["y"] = pheno_data["y"].squeeze()
-            elif pheno_data["y"].shape != (pheno_data["input_tensor_zarr"].shape[0],):
+            elif pheno_data["y"].shape != (n_samples,):
                 raise NotImplementedError(
                     "Multi-phenotype training is only implemented via multiple y files"
                 )
@@ -584,8 +593,10 @@ def __init__(
         self.n_annotations = any_pheno_data["input_tensor_zarr"].shape[2]
         self.n_covariates = any_pheno_data["covariates"].shape[1]
 
-        for _, pheno_data in self.data.items():
-            n_samples = pheno_data["input_tensor_zarr"].shape[0]
+        for pheno, pheno_data in self.data.items():
+            # n_samples = pheno_data["input_tensor_zarr"].shape[0]
+            sample_indices = pheno_data["sample_indices"]
+            n_samples = sample_indices.shape[0]
             assert pheno_data["covariates"].shape[0] == n_samples
             assert pheno_data["y"].shape[0] == n_samples
 
@@ -600,29 +611,33 @@ def __init__(
                 samples = self.upsample()
                 n_samples = self.samples.shape[0]
                 logger.info(f"New sample number: {n_samples}")
-            else:
-                samples = np.arange(n_samples)
+            # else:
+            #     samples = np.arange(n_samples)
 
            # Sample self.n_samples * train_proportion samples with replacement
            # for training, use all remaining samples for validation
            if train_proportion == 1.0:
-                self.train_samples = self.samples
-                self.val_samples = self.samples
+                self.train_samples = sample_indices
+                self.val_samples = sample_indices
            else:
                n_train_samples = round(n_samples * train_proportion)
                rng = np.random.default_rng()
                # select training samples from the underlying dataframe
                train_samples = np.sort(
                    rng.choice(
-                        samples, size=n_train_samples, replace=sample_with_replacement
+                        sample_indices,
+                        size=n_train_samples,
+                        replace=sample_with_replacement,
                    )
                )
                # samples which are not part of train_samples, but in samples
                # are validation samples.
-                pheno_data["samples"] = {
-                    "train": train_samples,
-                    "val": np.setdiff1d(samples, train_samples),
-                }
+                val_samples = np.setdiff1d(sample_indices, train_samples)
+                logger.info(
+                    f"{pheno}: Using {train_samples.shape[0]} samples for training, "
+                    f"{val_samples.shape[0]} for validation"
+                )
+                pheno_data["samples"] = {"train": train_samples, "val": val_samples}
 
         self.save_hyperparameters(
             # "min_variant_count",
@@ -640,6 +655,8 @@ def upsample(self) -> np.ndarray:
         does not work at the moment for multi-phenotype training.
         Needs some minor changes to make it work again
         """
+        raise NotImplementedError
+
         unique_values = self.y.unique()
         if unique_values.size() != torch.Size([2]):
             raise ValueError(
@@ -928,18 +945,19 @@ def run_bagging(
 
 @cli.command()
 @click.option("--debug", is_flag=True)
-@click.option("--training-gene-file", type=click.Path(exists=True))
+@click.option("--training-gene-file", type=click.Path(exists=True, path_type=Path))
 @click.option("--n-trials", type=int, default=1)
 @click.option("--trial-id", type=int)
-@click.option("--sample-file", type=click.Path(exists=True))
+@click.option("--samples-to-keep", type=click.Path(exists=True, path_type=Path))
 @click.option(
     "--phenotype",
     multiple=True,
     type=(
         str,
-        click.Path(exists=True),
-        click.Path(exists=True),
-        click.Path(exists=True),
+        click.Path(exists=True, path_type=Path),
+        click.Path(exists=True, path_type=Path),
+        click.Path(exists=True, path_type=Path),
+        click.Path(exists=True, path_type=Path),
     ),
 )
 @click.argument("config-file", type=click.Path(exists=True))
@@ -947,11 +965,11 @@ def run_bagging(
 @click.argument("hpopt-file", type=click.Path())
 def train(
     debug: bool,
-    training_gene_file: Optional[str],
+    training_gene_file: Optional[Path],
     n_trials: int,
     trial_id: Optional[int],
-    sample_file: Optional[str],
-    phenotype: Tuple[Tuple[str, str, str, str]],
+    samples_to_keep: Optional[Path],
+    phenotype: Tuple[Tuple[str, Path, Path, Path, Path]],
     config_file: str,
     log_dir: str,
     hpopt_file: str,
@@ -967,8 +985,8 @@ def train(
     :type n_trials: int
     :param trial_id: Current trial in range n_trials. (optional)
     :type trial_id: Optional[int]
-    :param sample_file: Path to a pickle file specifying which samples should be considered during training. (optional)
-    :type sample_file: Optional[str]
+    :param samples_to_keep: Path to a pickle file specifying which samples should be considered during training. (optional)
+    :type samples_to_keep: Optional[str]
     :param phenotype: Array of phenotypes, containing an array of paths where the underlying data is stored:
         - str: Phenotype name
        - str: Annotated gene variants as zarr file
@@ -1002,30 +1020,57 @@ def train(
         logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory}")
 
     logger.info("Loading input data")
-    if sample_file is not None:
-        logger.info(f"Using training samples from {sample_file}")
-        with open(sample_file, "rb") as f:
-            samples = pickle.load(f)["training_samples"]
-        if debug:
-            samples = [s for s in samples if s < 1000]
+
+    if samples_to_keep is not None or "sample_file" in config:
+        if samples_to_keep is not None:
+            keep_sample_file = samples_to_keep
+            if "sample_file" in config:
+                logger.warning(
+                    f"--samples-to-keep option overrides sample_file in config.yaml"
+                )
+        else:
+            keep_sample_file = Path(config["sample_file"])
+
+        logger.info(f"Using samples from {keep_sample_file}")
+        if keep_sample_file.suffix == ".pkl":
+            with open(keep_sample_file, "rb") as f:
+                sample_ids = np.array(pickle.load(f))
+        elif keep_sample_file.suffix == ".zarr":
+            sample_ids = zarr.load(keep_sample_file)
+        elif keep_sample_file.suffix == ".npy":
+            sample_ids = np.load(keep_sample_file)
+        else:
+            raise ValueError("Unknown file type for sample_file")
     else:
-        samples = slice(None)
+        sample_ids = None
 
     data = dict()
     # pack underlying data into a single dict that can be passed to downstream functions
-    for pheno, input_tensor_file, covariates_file, y_file in phenotype:
+    for pheno, input_tensor_file, covariates_file, y_file, sample_file in phenotype:
         data[pheno] = dict()
         data[pheno]["input_tensor_zarr"] = zarr.open(
             input_tensor_file, mode="r"
         )  # TODO: subset here?
+        n_samples = data[pheno]["input_tensor_zarr"].shape[0]
+
+        data[pheno]["sample_ids"] = zarr.load(sample_file)
+        sample_indices = np.arange(n_samples)
+        if sample_ids is not None:
+            sample_indices = sample_indices[
+                np.isin(data[pheno]["sample_ids"], sample_ids)
+            ]
+
+        if debug:
+            sample_indices = sample_indices[:1000]
+
+        data[pheno]["sample_indices"] = sample_indices
+
         data[pheno]["covariates"] = torch.tensor(
-            zarr.open(covariates_file, mode="r")[:]
-        )[
-            samples
-        ]  # TODO: or maybe shouldn't subset here?
-        data[pheno]["y"] = torch.tensor(zarr.open(y_file, mode="r")[:])[
-            samples
-        ]  # TODO: or maybe shouldn't subset here?
+            zarr.load(covariates_file)[sample_indices]
+        )  # TODO: or maybe shouldn't subset here?
+        data[pheno]["y"] = torch.tensor(
+            zarr.load(y_file)[sample_indices]
+        )  # TODO: or maybe shouldn't subset here?
 
     if training_gene_file is not None:
         with open(training_gene_file, "rb") as f:
diff --git a/deeprvat/utils.py b/deeprvat/utils.py
index a9c18801..2cbbc1ea 100644
--- a/deeprvat/utils.py
+++ b/deeprvat/utils.py
@@ -225,6 +225,7 @@ def safe_merge(
     right: pd.DataFrame,
     validate: str = "1:1",
     equal_row_nums: bool = False,
+    **kwargs,
 ):
     """
     Safely merge two pandas DataFrames.
@@ -251,7 +252,7 @@ def safe_merge(
             "left and right dataframes are unequal"
         )
 
-    merged = pd.merge(left, right, validate=validate)
+    merged = pd.merge(left, right, validate=validate, **kwargs)
 
     try:
         assert len(merged) == len(left)
diff --git a/example/config.yaml b/example/config.yaml
index c6a6a3fb..435a4b22 100644
--- a/example/config.yaml
+++ b/example/config.yaml
@@ -22,6 +22,9 @@ baseline_results:
     base: baseline_results
     type: missense/skat
 
+cv_path: sample_files
+n_folds: 5
+
 alpha: 0.05
 
 n_burden_chunks: 2
diff --git a/pipelines/association_testing/association_dataset.snakefile b/pipelines/association_testing/association_dataset.snakefile
index 9c8ba228..818e43de 100644
--- a/pipelines/association_testing/association_dataset.snakefile
+++ b/pipelines/association_testing/association_dataset.snakefile
@@ -1,8 +1,5 @@
 configfile: "config.yaml"
 
-debug_flag = config.get('debug', False)
-debug = '--debug ' if debug_flag else ''
-
 rule association_dataset:
     input:
         config = '{phenotype}/deeprvat/hpopt_config.yaml'
@@ -14,6 +11,5 @@ rule association_dataset:
     priority: 30
     shell:
         'deeprvat_associate make-dataset '
-        + debug
         '{input.config} '
         '{output}'
diff --git a/pipelines/cv_training/cv_burdens.snakefile b/pipelines/cv_training/cv_burdens.snakefile
index f618e136..0a6960bc 100644
--- a/pipelines/cv_training/cv_burdens.snakefile
+++ b/pipelines/cv_training/cv_burdens.snakefile
@@ -42,10 +42,6 @@ rule make_deeprvat_test_config:
 
 # pass the sample file here
 # then just use this data set nomrally for burden computation
 use rule association_dataset from deeprvat_associate as deeprvat_association_dataset with:
-    input:
-        config="cv_split{cv_split}/deeprvat/{phenotype}/deeprvat/hpopt_config_test.yaml",
-    output:
-        "cv_split{cv_split}/deeprvat/{phenotype}/deeprvat/association_dataset.pkl",
     threads: 4
 
@@ -93,14 +89,31 @@ rule combine_test_burdens:
     )
 
 
-use rule link_burdens from deeprvat_workflow as deeprvat_link_burdens with:
+use rule link_burdens from deeprvat_associate as deeprvat_link_burdens with:
+    input:
+        checkpoints = expand(
+            'cv_split{cv_split}/deeprvat' / model_path / "repeat_{repeat}/best/bag_{bag}.ckpt",
+            cv_split=range(cv_splits), repeat=range(n_repeats), bag=range(n_bags)
+        ),
+        dataset = 'cv_split0/deeprvat/{phenotype}/deeprvat/association_dataset.pkl',
+        data_config = 'cv_split{cv_split}/deeprvat/{phenotype}/deeprvat/hpopt_config_test.yaml',
+        model_config = "cv_split{cv_split}/deeprvat" / model_path / 'config.yaml',
     params:
         prefix="cv_split{cv_split}/deeprvat",
 
 
-use rule compute_burdens from deeprvat_workflow as deeprvat_compute_burdens with:
+use rule compute_burdens from deeprvat_associate as deeprvat_compute_burdens with:
+    input:
+        reversed = "cv_split{cv_split}/deeprvat" / model_path / "reverse_finished.tmp",
+        checkpoints = expand(
+            'cv_split{cv_split}/deeprvat' / model_path / "repeat_{repeat}/best/bag_{bag}.ckpt",
+            cv_split=range(cv_splits), repeat=range(n_repeats), bag=range(n_bags)
+        ),
+        dataset = 'cv_split0/deeprvat/{phenotype}/deeprvat/association_dataset.pkl',
+        data_config = 'cv_split{cv_split}/deeprvat/{phenotype}/deeprvat/hpopt_config_test.yaml',
+        model_config = "cv_split{cv_split}/deeprvat" / model_path / 'config.yaml',
     params:
         prefix="cv_split{cv_split}/deeprvat",
 
 
-use rule reverse_models from deeprvat_workflow as deeprvat_reverse_models
+use rule reverse_models from deeprvat_associate as deeprvat_reverse_models
diff --git a/pipelines/cv_training/cv_training.snakefile b/pipelines/cv_training/cv_training.snakefile
index 3c4bb674..0cd08647 100644
--- a/pipelines/cv_training/cv_training.snakefile
+++ b/pipelines/cv_training/cv_training.snakefile
@@ -52,18 +52,36 @@ use rule best_training_run from deeprvat_workflow as deeprvat_best_training_run
 
 use rule train from deeprvat_workflow as deeprvat_train with:
     priority: 1000
+    input:
+        config = expand('cv_split{{cv_split}}/deeprvat/{phenotype}/deeprvat/hpopt_config.yaml',
+                        phenotype=training_phenotypes),
+        input_tensor = expand('cv_split0/deeprvat/{phenotype}/deeprvat/input_tensor.zarr',
+                              phenotype=training_phenotypes),
+        covariates = expand('cv_split0/deeprvat/{phenotype}/deeprvat/covariates.zarr',
+                            phenotype=training_phenotypes),
+        y = expand('cv_split0/deeprvat/{phenotype}/deeprvat/y.zarr',
+                   phenotype=training_phenotypes),
+        samples = expand('cv_split0/deeprvat/{phenotype}/deeprvat/sample_ids.zarr',
+                         phenotype=training_phenotypes),
     params:
         prefix = 'cv_split{cv_split}/deeprvat',
         phenotypes = " ".join( #TODO like need the prefix here as well
             [f"--phenotype {p} "
-             f"cv_split{{cv_split}}/deeprvat/{p}/deeprvat/input_tensor.zarr "
-             f"cv_split{{cv_split}}/deeprvat/{p}/deeprvat/covariates.zarr "
-             f"cv_split{{cv_split}}/deeprvat/{p}/deeprvat/y.zarr"
+             f"cv_split0/deeprvat/{p}/deeprvat/input_tensor.zarr "
+             f"cv_split0/deeprvat/{p}/deeprvat/covariates.zarr "
+             f"cv_split0/deeprvat/{p}/deeprvat/y.zarr "
+             f"cv_split0/deeprvat/{p}/deeprvat/sample_ids.zarr "
             for p in training_phenotypes])
 
 
 use rule training_dataset from deeprvat_workflow as deeprvat_training_dataset
+    # output:
+    #     input_tensor=directory("{phenotype}/deeprvat/input_tensor.zarr"),
+    #     covariates=directory("{phenotype}/deeprvat/covariates.zarr"),
+    #     y=directory("{phenotype}/deeprvat/y.zarr"),
+    #     sample_ids=directory("{phenotype}/deeprvat/sample_ids.zarr"),
+
-use rule training_dataset_pickle from deeprvat_workflow as deeprvat_training_dataset_pickle
+# use rule training_dataset_pickle from deeprvat_workflow as deeprvat_training_dataset_pickle
 
 use rule config from deeprvat_workflow as deeprvat_config with:
     input:
@@ -80,11 +98,3 @@ use rule config from deeprvat_workflow as deeprvat_config with:
             ]) if wildcards.phenotype in training_phenotypes else ' ',
         baseline_out = lambda wildcards: f'--baseline-results-out cv_split{wildcards.cv_split}/deeprvat/{wildcards.phenotype}/deeprvat/baseline_results.parquet' if wildcards.phenotype in training_phenotypes else ' ',
         seed_genes_out = lambda wildcards: f'--seed-genes-out cv_split{wildcards.cv_split}/deeprvat/{wildcards.phenotype}/deeprvat/seed_genes.parquet' if wildcards.phenotype in training_phenotypes else ' '
-
-
-
-
-
-
-
-
diff --git a/pipelines/cv_training/cv_training_association_testing.snakefile b/pipelines/cv_training/cv_training_association_testing.snakefile
index 7effce79..14443e83 100644
--- a/pipelines/cv_training/cv_training_association_testing.snakefile
+++ b/pipelines/cv_training/cv_training_association_testing.snakefile
@@ -28,6 +28,8 @@ n_parallel_training_jobs = config["training"].get("n_parallel_jobs", 1)
 wildcard_constraints:
     repeat="\d+",
     trial="\d+",
+    cv_split="\d+",
+    phenotype="[A-z0-9_]+",
 
 cv_splits = config.get("n_folds", 5)
 
@@ -78,6 +80,25 @@ rule all_training: #cv_training.snakefile
             cv_split=range(cv_splits),
         ),
 
+rule all_training_dataset: #cv_training.snakefile
+    input:
+        expand(
+            "cv_split0/deeprvat/{phenotype}/deeprvat/input_tensor.zarr",
+            phenotype=training_phenotypes,
+        ),
+        expand(
+            "cv_split0/deeprvat/{phenotype}/deeprvat/covariates.zarr",
+            phenotype=training_phenotypes,
+        ),
+        expand(
+            "cv_split0/deeprvat/{phenotype}/deeprvat/y.zarr",
+            phenotype=training_phenotypes,
+        ),
+        expand(
+            "cv_split0/deeprvat/{phenotype}/deeprvat/sample_ids.zarr",
+            phenotype=training_phenotypes,
+        ),
+
 
 rule all_config: #cv_training.snakefile
     input:
diff --git a/pipelines/training/train.snakefile b/pipelines/training/train.snakefile
index f33fd5b6..4d8e9cd7 100644
--- a/pipelines/training/train.snakefile
+++ b/pipelines/training/train.snakefile
@@ -51,7 +51,8 @@ rule train:
             [f"--phenotype {p} "
              f"{p}/deeprvat/input_tensor.zarr "
              f"{p}/deeprvat/covariates.zarr "
-             f"{p}/deeprvat/y.zarr"
+             f"{p}/deeprvat/y.zarr "
+             f"{p}/deeprvat/sample_ids.zarr "
             for p in training_phenotypes]),
         prefix = '.',
     priority: 1000
@@ -64,7 +65,8 @@ rule train:
         + debug
         '--trial-id {{2}} '
         "{params.phenotypes} "
-        'config.yaml '
+        # 'config.yaml '
+        "{params.prefix}/config.yaml "
         '{params.prefix}/{model_path}/repeat_{{1}}/trial{{2}} '
         "{params.prefix}/{model_path}/repeat_{{1}}/hyperparameter_optimization.db '&&' "
         "touch {params.prefix}/{model_path}/repeat_{{1}}/trial{{2}}/finished.tmp "
diff --git a/pipelines/training/training_dataset.snakefile b/pipelines/training/training_dataset.snakefile
index 2cf00229..3c40760b 100644
--- a/pipelines/training/training_dataset.snakefile
+++ b/pipelines/training/training_dataset.snakefile
@@ -1,15 +1,16 @@
 rule training_dataset:
     input:
         config="{phenotype}/deeprvat/hpopt_config.yaml",
-        training_dataset="{phenotype}/deeprvat/training_dataset.pkl",
+        # training_dataset="{phenotype}/deeprvat/training_dataset.pkl",
     output:
         input_tensor=directory("{phenotype}/deeprvat/input_tensor.zarr"),
         covariates=directory("{phenotype}/deeprvat/covariates.zarr"),
         y=directory("{phenotype}/deeprvat/y.zarr"),
+        sample_ids=directory("{phenotype}/deeprvat/sample_ids.zarr"),
     threads: 8
     resources:
         mem_mb=lambda wildcards, attempt: 32000 + 12000 * attempt,
-        load=16000,
+        load=32000,
     priority: 5000
     shell:
         (
@@ -18,11 +19,12 @@ rule training_dataset:
             + "--compression-level "
            + str(tensor_compression_level)
            + " "
-            "--training-dataset-file {input.training_dataset} "
+            # "--training-dataset-file {input.training_dataset} "
            "{input.config} "
            "{output.input_tensor} "
            "{output.covariates} "
-            "{output.y}"
+            "{output.y} "
+            "{output.sample_ids}"
        )
 
 
@@ -30,11 +32,11 @@ rule training_dataset_pickle:
     input:
         "{phenotype}/deeprvat/hpopt_config.yaml",
     output:
-        "{phenotype}/deeprvat/training_dataset.pkl",
+        temp("{phenotype}/deeprvat/training_dataset.pkl"),
     threads: 1
     resources:
         mem_mb=40000,  # lambda wildcards, attempt: 38000 + 12000 * attempt
-        load=16000,
+        load=40000,
     shell:
         (
             "deeprvat_train make-dataset "
            "--debug "
            "--training-dataset-file {output} "
            "{input} "
            "dummy dummy dummy"
-        )
\ No newline at end of file
+        )
diff --git a/tests/deeprvat/test_train.py b/tests/deeprvat/test_train.py
index 7fbcbbfa..60d6605e 100644
--- a/tests/deeprvat/test_train.py
+++ b/tests/deeprvat/test_train.py
@@ -60,6 +60,7 @@ def make_multipheno_data():
         )
         data[p]["input_tensor"] = data[p]["input_tensor_zarr"][:]
         data[p]["samples"] = {"train": np.arange(data[p]["y"].shape[0])}
+        data[p]["sample_indices"] = np.arange(data[p]["y"].shape[0])
 
     return data
 
@@ -145,9 +146,10 @@ def test_make_dataset(phenotype: str, min_variant_count: int, tmp_path: Path):
         yaml.dump(config, f)
 
     # This is the function we want to test
-    input_tensor_out_file = str(tmp_path / "input_tensor.zarr")
-    covariates_out_file = str(tmp_path / "covariates.zarr")
-    y_out_file = str(tmp_path / "y.zarr")
+    input_tensor_out_file = tmp_path / "input_tensor.zarr"
+    covariates_out_file = tmp_path / "covariates.zarr"
+    y_out_file = tmp_path / "y.zarr"
+    samples_out_file = tmp_path / "sample_ids.zarr"
     logger.info("Constructing test dataset")
     test_ds = make_dataset_(
         False,
@@ -158,6 +160,7 @@ def test_make_dataset(phenotype: str, min_variant_count: int, tmp_path: Path):
         input_tensor_out_file,
         covariates_out_file,
         y_out_file,
+        samples_out_file,
     )
 
     # Load the data it output