From 01d8d55f4972a168cc66d2ee904c7d1d12aaddf6 Mon Sep 17 00:00:00 2001
From: Munzlinger
Date: Tue, 28 Nov 2023 16:19:46 +0100
Subject: [PATCH] added documentation

---
 deeprvat/deeprvat/models.py | 129 ++++++++++++++++++++++++++++++++++--
 deeprvat/deeprvat/train.py  | 116 +++++++++++++++++++++++++++++++-
 2 files changed, 235 insertions(+), 10 deletions(-)

diff --git a/deeprvat/deeprvat/models.py b/deeprvat/deeprvat/models.py
index 633bd63c..19d3f73c 100644
--- a/deeprvat/deeprvat/models.py
+++ b/deeprvat/deeprvat/models.py
@@ -41,8 +41,13 @@ def get_hparam(module: pl.LightningModule, param: str, default: Any):
     else:
         return default

-
 class BaseModel(pl.LightningModule):
+    """
+    Base class containing the functions that PyTorch Lightning calls in the
+    background by default.
+    """
+
+
     def __init__(
         self,
         config: dict,
@@ -53,6 +58,14 @@ def __init__(
         stage: str = "train",
         **kwargs,
     ):
+        """
+        config: dict, representing the content of config.yaml
+        n_annotations: dict, containing the number of annotations used for each phenotype
+        n_covariates: dict, containing the number of covariates used for each phenotype
+        n_genes: dict, containing the number of genes used for each phenotype
+        phenotypes: list, containing the phenotypes used during training
+        stage: str, prefix indicating the dataset the model is operating on
+        """
         super().__init__()
         self.save_hyperparameters(config)
         self.save_hyperparameters(kwargs)
@@ -74,7 +87,12 @@ def __init__(
         else:
             raise ValueError("Unknown objective_mode configuration parameter")

+
     def configure_optimizers(self) -> torch.optim.Optimizer:
+        """
+        Set up the optimizer and, if configured, the learning-rate scheduler
+        with the parameters specified in the config.
+        """
         optimizer_config = self.hparams["optimizer"]
         optimizer_class = getattr(torch.optim, optimizer_config["type"])
         optimizer = optimizer_class(
@@ -99,10 +117,18 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
         else:
             return optimizer

+
     def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
+        """
+        Called by the trainer during training; expected to return a loss on which
+        the backward pass can be computed.
+        """
+        # calls DeepSet.forward()
         y_pred_by_pheno = self(batch)
         results = dict()
+        # for each metric we want to evaluate (specified in config)
         for name, fn in self.metric_fns.items():
+            # compute the mean distance between ground truth and predicted scores
             results[name] = torch.mean(
                 torch.stack(
                     [
@@ -112,22 +138,31 @@ def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
                 )
             )
             self.log(f"{self.hparams.stage}_{name}", results[name])
-
+        # select the loss on which the backward pass is computed
         loss = results[self.hparams.metrics["loss"]]
         if torch.any(torch.isnan(loss)):
             raise RuntimeError("NaNs found in training loss")
         return loss

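The per-phenotype metric averaging documented in training_step can be illustrated with a small, self-contained PyTorch sketch; the phenotype names, tensor sizes, and the choice of MSE as the metric are invented for the example and are not taken from the patch:

import torch

# toy ground truth and predictions for two made-up phenotypes
y_by_pheno = {"pheno_a": torch.randn(8), "pheno_b": torch.randn(8)}
y_pred_by_pheno = {p: y + 0.1 * torch.randn_like(y) for p, y in y_by_pheno.items()}

metric_fns = {"MSE": torch.nn.functional.mse_loss}
results = {}
for name, fn in metric_fns.items():
    # compute the metric per phenotype, then average across phenotypes
    results[name] = torch.mean(
        torch.stack([fn(y_pred_by_pheno[p], y_by_pheno[p]) for p in y_by_pheno])
    )
print(results)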
     def validation_step(self, batch: dict, batch_idx: int):
+        """
+        During validation no backward passes are computed; instead, phenotype
+        predictions are accumulated and evaluated afterwards as a whole.
+        """
         y_by_pheno = {pheno: pheno_batch["y"] for pheno, pheno_batch in batch.items()}
         return {"y_pred_by_pheno": self(batch), "y_by_pheno": y_by_pheno}

+
     def validation_epoch_end(
         self, prediction_y: List[Dict[str, Dict[str, torch.Tensor]]]
     ):
+        """
+        Evaluate all accumulated phenotype predictions in one go at the end of the epoch.
+        """
         y_pred_by_pheno = dict()
         y_by_pheno = dict()
         for result in prediction_y:
+            # collect all predictions for each phenotype in a single tensor
             pred = result["y_pred_by_pheno"]
             for pheno, ys in pred.items():
                 y_pred_by_pheno[pheno] = torch.cat(
@@ -138,14 +173,16 @@ def validation_epoch_end(
                         ys,
                     ]
                 )
-
+            # collect the corresponding ground truth for each phenotype
             target = result["y_by_pheno"]
             for pheno, ys in target.items():
                 y_by_pheno[pheno] = torch.cat(
                     [y_by_pheno.get(pheno, torch.tensor([], device=self.device)), ys]
                 )

+        # compute each metric, averaged over all phenotypes
         results = dict()
+        # for each metric we want to evaluate (specified in config)
         for name, fn in self.metric_fns.items():
             results[name] = torch.mean(
                 torch.stack(
@@ -156,15 +193,25 @@ def validation_epoch_end(
                     )
                 )
             self.log(f"val_{name}", results[name])
-
+        # keep only the best (minimum or maximum, depending on the objective) value in
+        # self.best_objective to determine whether the last training epoch made progress
         self.best_objective = self.objective_operation(
             self.best_objective, results[self.hparams.metrics["objective"]].item()
         )

+
     def test_step(self, batch: dict, batch_idx: int):
+        """
+        During testing no backward passes are computed; instead, phenotype
+        predictions are accumulated and evaluated afterwards as a whole.
+        """
         return {"y_pred": self(batch), "y": batch["y"]}

+
     def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]):
+        """
+        Evaluate all accumulated phenotype predictions in one go at the end of the epoch.
+        """
         y_pred = torch.cat([p["y_pred"] for p in prediction_y])
         y = torch.cat([p["y"] for p in prediction_y])

@@ -180,8 +227,14 @@ def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]):
     def configure_callbacks(self):
         return [ModelSummary()]

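As a rough illustration of the accumulation documented in validation_epoch_end, the following standalone sketch (phenotype names and batch contents are invented) concatenates batch-wise predictions into one tensor per phenotype:

import torch

# outputs collected over two hypothetical validation batches
prediction_y = [
    {"y_pred_by_pheno": {"pheno_a": torch.randn(4), "pheno_b": torch.randn(4)}},
    {"y_pred_by_pheno": {"pheno_a": torch.randn(4), "pheno_b": torch.randn(4)}},
]

y_pred_by_pheno = dict()
for result in prediction_y:
    for pheno, ys in result["y_pred_by_pheno"].items():
        # append this batch's predictions to whatever has been accumulated so far
        y_pred_by_pheno[pheno] = torch.cat(
            [y_pred_by_pheno.get(pheno, torch.tensor([])), ys]
        )

print({pheno: t.shape for pheno, t in y_pred_by_pheno.items()})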
-
 class DeepSetAgg(pl.LightningModule):
+    """
+    Class containing the gene impairment module used for burden computation.
+    Variants are fed through an embedding network Phi to compute variant embeddings.
+    The variant embeddings are processed by a permutation-invariant aggregation to yield a gene embedding.
+    Afterwards, a second network Rho estimates the final gene impairment score.
+    All parameters of the gene impairment module are shared across genes and traits.
+    """
     def __init__(
         self,
         n_annotations: int,
@@ -196,6 +249,21 @@ def __init__(
         use_sigmoid: bool = False,
         reverse: bool = False,
     ):
+        """
+        n_annotations: int, number of annotations
+        phi_layers: int, number of layers in Phi
+        phi_hidden_dim: int, internal dimensionality of linear layers in Phi
+        rho_layers: int, number of layers in Rho
+        rho_hidden_dim: int, internal dimensionality of linear layers in Rho
+        activation: str, activation function used; has to match its name in torch.nn
+        pool: str, permutation-invariant aggregation function used to aggregate the variants of a gene;
+              possibilities: max, sum
+        output_dim: int, number of burden scores
+        dropout: float, probability with which activations are randomly set to 0
+        use_sigmoid: bool, whether to project burden scores to [0, 1]; also used as a
+                     linear activation function to mimic association testing during training
+        reverse: bool, whether to reverse the burden score (used during association testing)
+        """
         super().__init__()

         self.output_dim = output_dim
@@ -205,6 +273,7 @@ def __init__(
         self.use_sigmoid = use_sigmoid
         self.reverse = reverse

+        # set up Phi
         input_dim = n_annotations
         phi = []
         for l in range(phi_layers):
@@ -216,10 +285,12 @@ def __init__(
             input_dim = output_dim
         self.phi = nn.Sequential(OrderedDict(phi))

+        # set up the permutation-invariant aggregation function
         if pool not in ("sum", "max"):
             raise ValueError(f"Unknown pooling operation {pool}")
         self.pool = pool

+        # set up Rho
         rho = []
         for l in range(rho_layers - 1):
             output_dim = rho_hidden_dim
@@ -231,8 +302,12 @@ def __init__(
         rho.append(
             (f"rho_linear_{rho_layers - 1}", nn.Linear(input_dim, self.output_dim))
         )
+        # no final non-linear activation function, to keep the relationship between
+        # gene impairment scores and phenotypes linear
         self.rho = nn.Sequential(OrderedDict(rho))

+    # reverses the burden score during association testing if the model predicts in negative space;
+    # see associate.py, reverse_models() for further detail
     def set_reverse(self, reverse: bool = True):
         self.reverse = reverse

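A minimal, self-contained sketch of the Phi / pooling / Rho pattern described above; the layer sizes, the LeakyReLU activation, the two-layer depth, and the annotations-last tensor layout are illustrative assumptions and simpler than the actual gene impairment module:

from collections import OrderedDict

import torch
import torch.nn as nn

n_annotations, phi_hidden_dim, rho_hidden_dim, n_scores = 10, 16, 8, 1

# Phi: embeds each variant independently
phi = nn.Sequential(OrderedDict([
    ("phi_linear_0", nn.Linear(n_annotations, phi_hidden_dim)),
    ("phi_activation_0", nn.LeakyReLU()),
]))
# Rho: maps the aggregated gene embedding to the gene impairment score
rho = nn.Sequential(OrderedDict([
    ("rho_linear_0", nn.Linear(phi_hidden_dim, rho_hidden_dim)),
    ("rho_activation_0", nn.LeakyReLU()),
    ("rho_linear_1", nn.Linear(rho_hidden_dim, n_scores)),
]))

# toy input: samples x genes x variants x annotations
x = torch.randn(4, 2, 7, n_annotations)
variant_emb = phi(x)                      # samples x genes x variants x phi_hidden_dim
gene_emb = variant_emb.max(dim=2).values  # permutation-invariant "max" pooling over variants
score = rho(gene_emb)                     # samples x genes x n_scores
print(score.shape)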
@@ -252,8 +327,13 @@ def forward(self, x):
             x = torch.sigmoid(x)
         return x

-
 class DeepSet(BaseModel):
+    """
+    Wrapper class for burden computation that also performs phenotype prediction.
+    It inherits parameters from BaseModel, which is where the PyTorch Lightning specific functions
+    such as "training_step" or "validation_epoch_end" can be found.
+    Those functions are called in the background by default.
+    """
     def __init__(
         self,
         config: dict,
@@ -266,6 +346,18 @@ def __init__(
         reverse: bool = False,
         **kwargs,
     ):
+        """
+        config: dict, representing the content of config.yaml
+        n_annotations: dict, containing the number of annotations used for each phenotype
+        n_covariates: dict, containing the number of covariates used for each phenotype
+        n_genes: dict, containing the number of genes used for each phenotype
+        phenotypes: list, containing the phenotypes used during training
+        agg_model: pl.LightningModule / nn.Module, model used for burden computation;
+                   if not given, it will be initialized here
+        use_sigmoid: bool, whether to project burden scores to [0, 1]; also used as a
+                     linear activation function to mimic association testing during training
+        reverse: bool, whether to reverse the burden score (used during association testing)
+        """
         super().__init__(
             config, n_annotations, n_covariates, n_genes, phenotypes, **kwargs
         )
@@ -294,6 +386,13 @@ def __init__(
         )
         self.agg_model.train(False if self.hparams.stage == "val" else True)

+        # dict of linear layers used for phenotype prediction, whose outputs
+        # can be compared against the ground truth data.
+        # self.agg_model compresses a batch
+        # from: samples x genes x annotations x variants
+        # to:   samples x genes;
+        # afterwards the gene dimension is concatenated with the covariates
+        # to:   samples x (genes + covariates)
         self.gene_pheno = nn.ModuleDict(
             {
                 pheno: nn.Linear(
@@ -302,8 +401,23 @@ def __init__(
                 for pheno in self.hparams.phenotypes
             }
         )
-
+
     def forward(self, batch):
+        """
+        batch: dict of phenotypes, each containing the following keys:
+            - indices: tensor,
+              indices into the underlying dataframe
+
+            - covariates: tensor,
+              covariates of the samples, e.g. age;
+              shape: samples x covariates
+
+            - rare_variant_annotations: tensor,
+              annotated genomic variants;
+              shape: samples x genes x annotations x variants
+            - y: tensor,
+              actual phenotypes (ground truth data)
+        """
         result = dict()
         for pheno, this_batch in batch.items():
             x = this_batch["rare_variant_annotations"]
@@ -380,3 +494,4 @@ def forward(self, batch):
         x = torch.cat((batch["covariates"], x), dim=1)
         x = self.gene_pheno(x).squeeze(dim=1)  # samples
         return x
+
diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py
index 8ee0f465..1b478482 100644
--- a/deeprvat/deeprvat/train.py
+++ b/deeprvat/deeprvat/train.py
@@ -218,6 +218,10 @@ def make_dataset(


 class MultiphenoDataset(Dataset):
+    """
+    Class used to structure the data and to expose a __getitem__ function to
+    the dataloader, which is used to load batches into the model.
+    """
     def __init__(
         self,
         # input_tensor: zarr.core.Array,
@@ -231,7 +235,13 @@ def __init__(
         # samples: Optional[Union[slice, np.ndarray]] = None,
         # genes: Optional[Union[slice, np.ndarray]] = None
     ):
-        "Initialization"
+        """
+        data: dict, underlying dataframe from which the data is structured into batches
+        min_variant_count: int, minimum number of variants available for each gene
+        batch_size: int, number of samples/individuals in one batch
+        split: str, prefix indicating the dataset the model is operating on
+        cache_tensors: bool, indicates whether samples have been pre-loaded or need to be extracted from zarr
+        """
         super().__init__()

         self.data = data
@@ -263,6 +273,8 @@ def __init__(
         self.total_samples = sum([s.shape[0] for s in self.samples.values()])

         self.batch_size = batch_size
+        # index all samples and label them by phenotype, so that we
+        # get a dataframe representing a chain of phenotypes
         self.sample_order = pd.DataFrame(
             {
                 "phenotype": itertools.chain(
@@ -274,6 +286,7 @@ def __init__(
             {"phenotype": pd.api.types.CategoricalDtype()}
         )
         self.sample_order = self.sample_order.sample(n=self.total_samples)  # shuffle
+        # phenotype-specific index, e.g. the 7th element overall may be the 2nd element for phenotype "Urate"
         self.sample_order["index"] = self.sample_order.groupby("phenotype").cumcount()

     def __len__(self):
@@ -294,6 +307,7 @@ def __getitem__(self, index):

         result = dict()
         for pheno, df in samples_by_pheno:
+            # get the phenotype-specific sub-index
             idx = df["index"].to_numpy()

             annotations = (
@@ -310,8 +324,13 @@ def __getitem__(self, index):
             }

         return result

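The sample_order bookkeeping described in the comments above can be sketched in isolation as follows; the phenotype names and sample counts are invented, and the real dataset builds the same structure from its per-phenotype sample arrays:

import itertools

import numpy as np
import pandas as pd

# hypothetical per-phenotype sample indices
samples = {"Urate": np.arange(5), "LDL": np.arange(3)}
total_samples = sum(len(s) for s in samples.values())

# one row per sample, labelled with its phenotype
sample_order = pd.DataFrame(
    {
        "phenotype": itertools.chain(
            *[[pheno] * len(idx) for pheno, idx in samples.items()]
        )
    }
).astype({"phenotype": pd.api.types.CategoricalDtype()})

sample_order = sample_order.sample(n=total_samples)  # shuffle
# phenotype-specific index: the k-th row overall may be the i-th "Urate" sample
sample_order["index"] = sample_order.groupby("phenotype").cumcount()
print(sample_order)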
+
     def subset_samples(self):
+        """
+        Function used to filter out samples whose phenotype is NaN, as well as samples
+        with fewer variants per gene than the specified minimum.
+        """
         for pheno, pheno_data in self.data.items():
             # First sum over annotations (dim 2) for each variant in each gene.
             # Then get the number of non-zero values across all variants in all
@@ -335,6 +354,9 @@


 class MultiphenoBaggingData(pl.LightningDataModule):
+    """
+    Preprocesses the underlying dataframe and loads it into a dataset object.
+    """
     def __init__(
         self,
         data: Dict[str, Dict],
@@ -346,6 +368,17 @@ def __init__(
         num_workers: Optional[int] = 0,
         cache_tensors: bool = False,
     ):
+        """
+        data: dict, underlying dataframe from which the data is structured into batches
+        train_proportion: float, proportion of the data that goes into the training split (the rest is used for validation)
+        sample_with_replacement: bool, if True, a sample can be selected multiple times in one epoch
+        min_variant_count: int, minimum number of variants available for each gene
+        upsampling_factor: int, factor by which the data is upsampled; must be >= 1;
+                           however, not yet implemented for multi-phenotype training!
+        batch_size: int, number of samples/individuals in one batch
+        num_workers: int, number of workers that simultaneously load data into RAM
+        cache_tensors: bool, indicates whether samples have been pre-loaded or need to be extracted from zarr
+        """
         logger.info("Intializing datamodule")

         super().__init__()
@@ -391,11 +424,14 @@ def __init__(
             else:
                 n_train_samples = round(n_samples * train_proportion)
                 rng = np.random.default_rng()
+                # select training samples from the underlying dataframe
                 train_samples = np.sort(
                     rng.choice(
                         samples, size=n_train_samples, replace=sample_with_replacement
                     )
                 )
+                # samples that are in samples but not in train_samples
+                # are used as validation samples
                 pheno_data["samples"] = {
                     "train": train_samples,
                     "val": np.setdiff1d(samples, train_samples),
                 }
@@ -408,8 +444,12 @@ def __init__(
             "num_workers",
             "cache_tensors",
         )
-
+
     def upsample(self) -> np.ndarray:
+        """
+        Does not work at the moment for multi-phenotype training; needs some minor
+        changes to make it work again.
+        """
         unique_values = self.y.unique()
         if unique_values.size() != torch.Size([2]):
             raise ValueError(
@@ -432,7 +472,13 @@ def upsample(self) -> np.ndarray:
         self.samples = upsampled_indices

+
     def train_dataloader(self):
+        """
+        The training samples have been selected; to structure them and load them
+        as batches, they are packed into a dataset class, which is then wrapped
+        by a dataloader object.
+        """
         logger.info(
             "Instantiating training dataloader "
             f"with batch size {self.hparams.batch_size}"
         )
@@ -447,8 +493,14 @@ def train_dataloader(self):
         return DataLoader(
             dataset, batch_size=None, num_workers=self.hparams.num_workers
         )

+
     def val_dataloader(self):
+        """
+        The validation samples have been selected; to structure them and load them
+        as batches, they are packed into a dataset class, which is then wrapped
+        by a dataloader object.
+        """
         logger.info(
             "Instantiating validation dataloader "
             f"with batch size {self.hparams.batch_size}"
         )
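The "dataset wrapped by a dataloader" pattern from the two docstrings above can be sketched as follows; the toy dataset already returns whole batches from __getitem__, which is why the DataLoader is created with batch_size=None (automatic batching disabled). All names and sizes here are illustrative, not taken from the patch:

import torch
from torch.utils.data import DataLoader, Dataset


class PrebatchedDataset(Dataset):
    # toy map-style dataset whose __getitem__ already returns a full batch
    def __init__(self, n_samples: int = 20, batch_size: int = 8):
        self.x = torch.randn(n_samples, 3)
        self.batch_size = batch_size

    def __len__(self):
        # number of batches, not number of samples
        return (self.x.shape[0] + self.batch_size - 1) // self.batch_size

    def __getitem__(self, index):
        start = index * self.batch_size
        return {"x": self.x[start : start + self.batch_size]}


# batch_size=None disables automatic batching, so every item yielded by the
# dataset is passed through as one ready-made batch
loader = DataLoader(PrebatchedDataset(), batch_size=None, num_workers=0)
for batch in loader:
    print(batch["x"].shape)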
config["model"]["config"] = suggest_hparams(config["model"]["config"], trial) logger.info("Model hyperparameters this trial:") pprint(config["model"]["config"]) @@ -487,6 +568,8 @@ def run_bagging( with open(config_out, "w") as f: yaml.dump(config, f) + # in practice we only train a single bag, as there are + # theoretical reasons to omit bagging w.r.t. association testing n_bags = config["training"]["n_bags"] if not debug else 3 train_proportion = config["training"].get("train_proportion", None) logger.info(f"Training {n_bags} bagged models") @@ -516,13 +599,15 @@ def run_bagging( "cache_tensors", ) } + # load data into the required formate dm = MultiphenoBaggingData( this_data, train_proportion, **dm_kwargs, **config["training"]["dataloader_config"], ) - + + # setup the model architecture as specified in config model_class = getattr(deeprvat_models, config["model"]["type"]) model = model_class( config=config["model"]["config"], @@ -540,6 +625,8 @@ def run_bagging( objective = "val_" + config["model"]["config"]["metrics"]["objective"] checkpoint_callback = ModelCheckpoint(monitor=objective) callbacks = [checkpoint_callback] + + # to prune underperforming trials we enable a pruning strategy that can be set in config if "early_stopping" in config: callbacks.append( EarlyStopping(monitor=objective, **config["early_stopping"]) @@ -549,14 +636,17 @@ def run_bagging( config["pl_trainer"]["min_epochs"] = 10 config["pl_trainer"]["max_epochs"] = 20 + # initialize trainer, which will call background functionality trainer = pl.Trainer( logger=tb_logger, callbacks=callbacks, **config.get("pl_trainer", {}) ) while True: try: + # actual training of the model trainer.fit(model, dm) except RuntimeError as e: + # if batch_size is choosen to big, it will be reduced until it fits the GPU logging.error(f"Caught RuntimeError: {e}") if str(e).find("CUDA out of memory") != -1: if dm.hparams.batch_size > 4: @@ -655,6 +745,25 @@ def train( log_dir: str, hpopt_file: str, ): + """ + Main function called during training. Also used for trial pruning and sampling new parameters in optuna + + debug: bool, use a strongly reduced dataframe during training + training_gene_file: str, path to a pickle file specifying on which genes training should be executed + n_trials: int, number of trials to be performed by the given setting + trial_id: int, current trial in range n_trials + sample_file: str, path to a pickle file, which specifies which samples should be considered during training + phenotype: array of phenotypes, containing an array of paths, where the underlying data is stored: + 1. str containing the phenotype name + 2. annotated gene variants as zarr file + 3. covariates each sample as zarr file + 4. ground truth phenotypes as zarr file + config_file: str, path to a yaml file, which serves for configuration + log_dir: str, path to were logs are written + hpopt_file: str, path to where a .db file should be created in which the results of hyperparameter + optimization are stored + """ + if len(phenotype) == 0: raise ValueError("At least one --phenotype option must be specified") @@ -682,6 +791,7 @@ def train( samples = slice(None) data = dict() + # pack underlying data into a single dict that can be passed to downstream functions for pheno, input_tensor_file, covariates_file, y_file in phenotype: data[pheno] = dict() data[pheno]["input_tensor_zarr"] = zarr.open(input_tensor_file, mode="r")