From 01d8d55f4972a168cc66d2ee904c7d1d12aaddf6 Mon Sep 17 00:00:00 2001 From: Munzlinger Date: Tue, 28 Nov 2023 16:19:46 +0100 Subject: [PATCH 1/4] added documentation --- deeprvat/deeprvat/models.py | 129 ++++++++++++++++++++++++++++++++++-- deeprvat/deeprvat/train.py | 116 +++++++++++++++++++++++++++++++- 2 files changed, 235 insertions(+), 10 deletions(-) diff --git a/deeprvat/deeprvat/models.py b/deeprvat/deeprvat/models.py index 633bd63c..19d3f73c 100644 --- a/deeprvat/deeprvat/models.py +++ b/deeprvat/deeprvat/models.py @@ -41,8 +41,13 @@ def get_hparam(module: pl.LightningModule, param: str, default: Any): else: return default - class BaseModel(pl.LightningModule): + """ + Base class containing functions that will be called by PyTorch Lightning in the + background by default. + """ + + def __init__( self, config: dict, @@ -53,6 +58,14 @@ def __init__( stage: str = "train", **kwargs, ): + """ + config: dict, representing the content of config.yaml + n_annotations: dict, containing the number of annotations used each phenotype + n_covariates: dict, containing the number of covariates used each phenotype + n_genes: dict, containing the number of genes used each phenotype + phenotypes: list, containing phenotypes used during training + stage: str, containing a prefix, pointing to the dataset the model is operating on + """ super().__init__() self.save_hyperparameters(config) self.save_hyperparameters(kwargs) @@ -74,7 +87,12 @@ def __init__( else: raise ValueError("Unknown objective_mode configuration parameter") + def configure_optimizers(self) -> torch.optim.Optimizer: + """ + Function used to setup an optimizer and scheduler by their + parameters which are specified in config + """ optimizer_config = self.hparams["optimizer"] optimizer_class = getattr(torch.optim, optimizer_config["type"]) optimizer = optimizer_class( @@ -99,10 +117,18 @@ def configure_optimizers(self) -> torch.optim.Optimizer: else: return optimizer + def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: + """ + Function is called by the trainer during training and is expected to return + a loss from which we can compute backward passes. + """ + # calls DeepSet.forward() y_pred_by_pheno = self(batch) results = dict() + # for all metrics we want to evaluate (specified in config) for name, fn in self.metric_fns.items(): + # compute mean distance in between ground truth and predicted score. results[name] = torch.mean( torch.stack( [ @@ -112,22 +138,31 @@ def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: ) ) self.log(f"{self.hparams.stage}_{name}", results[name]) - + # set loss from which we compute backward passes loss = results[self.hparams.metrics["loss"]] if torch.any(torch.isnan(loss)): raise RuntimeError("NaNs found in training loss") return loss def validation_step(self, batch: dict, batch_idx: int): + """ + During validation we do not compute backward passes, such that we can accumulate + phenotype predictions and evaluate them afterwards as a whole. 
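For illustration (editor's sketch; the phenotype name is taken from an example elsewhere in this patch, the tensor contents are invented), the dictionary returned at the end of this method has the following structure and is later consumed by validation_epoch_end():
    # {
    #     "y_pred_by_pheno": {"Urate": <tensor of model predictions for this batch>},
    #     "y_by_pheno":      {"Urate": <tensor of matching ground-truth values>},
    # }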
+ """ y_by_pheno = {pheno: pheno_batch["y"] for pheno, pheno_batch in batch.items()} return {"y_pred_by_pheno": self(batch), "y_by_pheno": y_by_pheno} + def validation_epoch_end( self, prediction_y: List[Dict[str, Dict[str, torch.Tensor]]] ): + """ + Evaluate all phenotype predictions in one go after accumulating them earlier + """ y_pred_by_pheno = dict() y_by_pheno = dict() for result in prediction_y: + # create a dict for each phenotype that includes all respective predictions pred = result["y_pred_by_pheno"] for pheno, ys in pred.items(): y_pred_by_pheno[pheno] = torch.cat( @@ -138,14 +173,16 @@ def validation_epoch_end( ys, ] ) - + # create a dict for each phenotype that includes the respective ground truth target = result["y_by_pheno"] for pheno, ys in target.items(): y_by_pheno[pheno] = torch.cat( [y_by_pheno.get(pheno, torch.tensor([], device=self.device)), ys] ) + # create a dict for each phenotype that stores the respective loss results = dict() + # for all metrics we want to evaluate (specified in config) for name, fn in self.metric_fns.items(): results[name] = torch.mean( torch.stack( @@ -156,15 +193,25 @@ def validation_epoch_end( ) ) self.log(f"val_{name}", results[name]) - + # consider all metrics only store the most min/max in self.best_objective + # to determine if progress was made in the last training epoch. self.best_objective = self.objective_operation( self.best_objective, results[self.hparams.metrics["objective"]].item() ) + def test_step(self, batch: dict, batch_idx: int): + """ + During testing we do not compute backward passes, such that we can accumulate + phenotype predictions and evaluate them afterwards as a whole. + """ return {"y_pred": self(batch), "y": batch["y"]} + def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]): + """ + Evaluate all phenotype predictions in one go after accumulating them earlier + """ y_pred = torch.cat([p["y_pred"] for p in prediction_y]) y = torch.cat([p["y"] for p in prediction_y]) @@ -180,8 +227,14 @@ def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]): def configure_callbacks(self): return [ModelSummary()] - class DeepSetAgg(pl.LightningModule): + """ + class contains the gene impairment module used for burden computation. + Variants are fed through an embedding network Phi, to compute a variant embedding + The variant embedding is processed by a permutation-invariant aggregation to yield a gene embedding. + Afterwards second network Rho, estimates the final gene impairment score. + All parameters of the gene impairment module are shared across genes and traits. + """ def __init__( self, n_annotations: int, @@ -196,6 +249,21 @@ def __init__( use_sigmoid: bool = False, reverse: bool = False, ): + """ + n_annotations: int, number of annotations + phi_layers: int, number of layers in Phi + phi_hidden_dim: int, internal dimensionality of linear layers in Phi + rho_layers: int, number of layers in Rho + rho_hidden_dim: int, internal dimensionality of linear layers in Rho + activation: str, activation function used, has to match its name in torch.nn + pool: str, invariant aggregation function used to aggregate gene variants + possiblities: max, sum + output_dim: int, number of burden scores + dropout: float, propability, by which some parameters are set to 0 + use_sigmoid: booleon, to project burden scores to [0, 1]. Also used as a + linear activation function to mimic association testing during training + reverse: booleon, to reverse the burden score. 
(used during association testing) + """ super().__init__() self.output_dim = output_dim @@ -205,6 +273,7 @@ def __init__( self.use_sigmoid = use_sigmoid self.reverse = reverse + # setup of Phi input_dim = n_annotations phi = [] for l in range(phi_layers): @@ -216,10 +285,12 @@ def __init__( input_dim = output_dim self.phi = nn.Sequential(OrderedDict(phi)) + # setup permutation-invariant aggregation function if pool not in ("sum", "max"): raise ValueError(f"Unknown pooling operation {pool}") self.pool = pool + # setup of Rho rho = [] for l in range(rho_layers - 1): output_dim = rho_hidden_dim @@ -231,8 +302,12 @@ def __init__( rho.append( (f"rho_linear_{rho_layers - 1}", nn.Linear(input_dim, self.output_dim)) ) + # No final non-linear activation function to keep the relationship between + # gene impairment scores and phenotypes linear self.rho = nn.Sequential(OrderedDict(rho)) + # reverse burden score during association testing if model predicts in negative space. + # compare associate.py, reverse_models() for further detail. def set_reverse(self, reverse: bool = True): self.reverse = reverse @@ -252,8 +327,13 @@ def forward(self, x): x = torch.sigmoid(x) return x - class DeepSet(BaseModel): + """ + Wrapper class for burden computation, that also does phenotype prediction. + It inherits parameters from BaseModel, which is where Pytorch Lightning specific functions + like "training_step" or "validation_epoch_end" can be found. + Those functions are called in background by default. + """ def __init__( self, config: dict, @@ -266,6 +346,18 @@ def __init__( reverse: bool = False, **kwargs, ): + """ + config: dict, representing the content of config.yaml + n_annotations: dict, containing the number of annotations used each phenotype + n_covariates: dict, containing the number of covariates used each phenotype + n_genes: dict, containing the number of genes used each phenotype + phenotypes: list, containing phenotypes used during training + agg_model: pl.LightningModule / nn.Module, model used for burden computation + if module isnt given, it will be initialized + use_sigmoid: booleon, to project burden scores to [0, 1]. Also used as a + linear activation function to mimic association testing during training + reverse: booleon, to reverse the burden score. (used during association testing) + """ super().__init__( config, n_annotations, n_covariates, n_genes, phenotypes, **kwargs ) @@ -294,6 +386,13 @@ def __init__( ) self.agg_model.train(False if self.hparams.stage == "val" else True) + # dict of various linear layers used for phenotype prediction. + # Returns can be tested against ground truth data. + # self.agg_model compresses a batch + # from: samples x genes x annotations x variants; + # to: samples x genes + # afterwards genes are concatenated with covariates + # to: samples x (genes + covariates) self.gene_pheno = nn.ModuleDict( { pheno: nn.Linear( @@ -302,8 +401,23 @@ def __init__( for pheno in self.hparams.phenotypes } ) - + def forward(self, batch): + """ + batch: dict of phenotypes, each containing the following keys: + - indices: tensor, + indices for the underlying dataframe + + - covariates: tensor, + covariates of samples e.g. 
age, + content: samples x covariates + + - rare_variant_annotations: tensor, + actual annotated genomic variants, + content: samples x genes x annotations x variants + - y: tensor, + actual phenotypes (ground truth data) + """ result = dict() for pheno, this_batch in batch.items(): x = this_batch["rare_variant_annotations"] @@ -380,3 +494,4 @@ def forward(self, batch): x = torch.cat((batch["covariates"], x), dim=1) x = self.gene_pheno(x).squeeze(dim=1) # samples return x + diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index 8ee0f465..1b478482 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -218,6 +218,10 @@ def make_dataset( class MultiphenoDataset(Dataset): + """ + class used to structure the data and present a __getitem__ function to + the dataloader, that will be used to load batches into the model + """ def __init__( self, # input_tensor: zarr.core.Array, @@ -231,7 +235,13 @@ def __init__( # samples: Optional[Union[slice, np.ndarray]] = None, # genes: Optional[Union[slice, np.ndarray]] = None ): - "Initialization" + """ + data: dict, underlying dataframe from which data is structured into batches + min_variant_count: int, minimum number of variants available each gene. + batch_size: int, number of samples / individuals available in one batch + split: str, containing a prefix, pointing to the dataset the model is operating on + cache_tensors: bool, indicates if samples have been pre-loaded or need to be extracted from zarr + """ super().__init__() self.data = data @@ -263,6 +273,8 @@ def __init__( self.total_samples = sum([s.shape[0] for s in self.samples.values()]) self.batch_size = batch_size + # index all samples and categorize them by phenotype, such that we + # get a dataframe repreenting a chain of phenotypes self.sample_order = pd.DataFrame( { "phenotype": itertools.chain( @@ -274,6 +286,7 @@ def __init__( {"phenotype": pd.api.types.CategoricalDtype()} ) self.sample_order = self.sample_order.sample(n=self.total_samples) # shuffle + # phenotype specific index; e.g. 7. element total, 2. element for phenotype "Urate" self.sample_order["index"] = self.sample_order.groupby("phenotype").cumcount() def __len__(self): @@ -294,6 +307,7 @@ def __getitem__(self, index): result = dict() for pheno, df in samples_by_pheno: + # get phenotype specific sub-index idx = df["index"].to_numpy() annotations = ( @@ -310,8 +324,13 @@ def __getitem__(self, index): } return result + def subset_samples(self): + """ + Function used to sort out samples which contain real phenotypes with NaN values and + samples with less variants each gene then what we specify as a minimum. + """ for pheno, pheno_data in self.data.items(): # First sum over annotations (dim 2) for each variant in each gene. # Then get the number of non-zero values across all variants in all @@ -335,6 +354,9 @@ def subset_samples(self): class MultiphenoBaggingData(pl.LightningDataModule): + """ + Preprocess the underlying dataframe, to then load it into a dataset object + """ def __init__( self, data: Dict[str, Dict], @@ -346,6 +368,17 @@ def __init__( num_workers: Optional[int] = 0, cache_tensors: bool = False, ): + """ + data: dict, underlying dataframe from which data is structured into batches + train_proportion: float, percentage by which data is devided into training / validation split + sample_with_replacement: bool, if true a sample of a can be selected multiple times in one epoch. + min_variant_count: int minimum number of variants available each gene. 
+ upsampling_factor: int, percentual factor by which we want to upsample data; >= 1; + however, not yet implemented for multi-phenotype training! + batch_size: int, number of samples / individuals available in one batch + num_workers: int, number of workers which simultaneously putting data into RAM + cache_tensors: bool, indicates if samples have been pre-loaded or need to be extracted from zarr + """ logger.info("Intializing datamodule") super().__init__() @@ -391,11 +424,14 @@ def __init__( else: n_train_samples = round(n_samples * train_proportion) rng = np.random.default_rng() + # select training samples from the underlying dataframe train_samples = np.sort( rng.choice( samples, size=n_train_samples, replace=sample_with_replacement ) ) + # samples which are not part of train_samples, but in samples + # are validation samples. pheno_data["samples"] = { "train": train_samples, "val": np.setdiff1d(samples, train_samples), @@ -408,8 +444,12 @@ def __init__( "num_workers", "cache_tensors", ) - + def upsample(self) -> np.ndarray: + """ + does not work at the moment for multi-phenotype training. Needs some minor changes + to make it work again + """ unique_values = self.y.unique() if unique_values.size() != torch.Size([2]): raise ValueError( @@ -432,7 +472,13 @@ def upsample(self) -> np.ndarray: self.samples = upsampled_indices + def train_dataloader(self): + """ + trainning samples have been selected, but to structure them and make them load + as a batch they are packed in a dataset class, which is then wrapped by a + dataloading object. + """ logger.info( "Instantiating training dataloader " f"with batch size {self.hparams.batch_size}" @@ -447,8 +493,14 @@ def train_dataloader(self): return DataLoader( dataset, batch_size=None, num_workers=self.hparams.num_workers ) + def val_dataloader(self): + """ + validation samples have been selected, but to structure them and make them load + as a batch they are packed in a dataset class, which is then wrapped by a + dataloading object. + """ logger.info( "Instantiating validation dataloader " f"with batch size {self.hparams.batch_size}" @@ -474,10 +526,39 @@ def run_bagging( trial_id: Optional[int] = None, debug: bool = False, ) -> Optional[float]: + """ + Main function called during training. Also used for trial pruning and sampling new parameters in optuna + + config: dict, build from yaml, which serves for configuration + data: dict of phenotypes, each containing a dict, storing the underlying data + log_dir: str, path to were logs are written + checkpoint_file: str, path to where the weights of the trained model should be saved + trial: optuna object, generated from the study + trial_id: int, current trial in range n_trials + debug: bool, use a strongly reduced dataframe during training + """ + + # if hyperparameter optimization is performed (train(); hpopt_file != None) if trial is not None: if trial_id is not None: + # differentiate various repeats in their individual optimization trial.set_user_attr("user_id", trial_id) + # Parameters set in config can be used to indicate hyperparameter optimization. + # Such cases can be spotted by the following exemplary pattern: + # + # phi_hidden_dim: 20 + # hparam: + # type: int + # args: + # - 16 + # - 64 + # kwargs: + # step: 16 + # + # this line should be translated into: + # phi_layers = optuna.suggest_int(name="phi_hidden_dim", low=16, high=64, step=16) + # and afterwards replace the respective area in config to set the suggestion. 
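# Editor's note -- a rough, hypothetical sketch of that translation. This is NOT the
# actual suggest_hparams() implementation in this repo; _suggest is an invented helper
# name, shown only to make the comment above concrete:
#
#     def _suggest(trial, name, spec):
#         # e.g. spec == {"type": "int", "args": [16, 64], "kwargs": {"step": 16}}
#         fn = getattr(trial, f"suggest_{spec['type']}")  # -> trial.suggest_int, ...
#         return fn(name, *spec["args"], **spec.get("kwargs", {}))
#
#     # _suggest(trial, "phi_hidden_dim", spec)
#     # -> trial.suggest_int("phi_hidden_dim", 16, 64, step=16)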
config["model"]["config"] = suggest_hparams(config["model"]["config"], trial) logger.info("Model hyperparameters this trial:") pprint(config["model"]["config"]) @@ -487,6 +568,8 @@ def run_bagging( with open(config_out, "w") as f: yaml.dump(config, f) + # in practice we only train a single bag, as there are + # theoretical reasons to omit bagging w.r.t. association testing n_bags = config["training"]["n_bags"] if not debug else 3 train_proportion = config["training"].get("train_proportion", None) logger.info(f"Training {n_bags} bagged models") @@ -516,13 +599,15 @@ def run_bagging( "cache_tensors", ) } + # load data into the required formate dm = MultiphenoBaggingData( this_data, train_proportion, **dm_kwargs, **config["training"]["dataloader_config"], ) - + + # setup the model architecture as specified in config model_class = getattr(deeprvat_models, config["model"]["type"]) model = model_class( config=config["model"]["config"], @@ -540,6 +625,8 @@ def run_bagging( objective = "val_" + config["model"]["config"]["metrics"]["objective"] checkpoint_callback = ModelCheckpoint(monitor=objective) callbacks = [checkpoint_callback] + + # to prune underperforming trials we enable a pruning strategy that can be set in config if "early_stopping" in config: callbacks.append( EarlyStopping(monitor=objective, **config["early_stopping"]) @@ -549,14 +636,17 @@ def run_bagging( config["pl_trainer"]["min_epochs"] = 10 config["pl_trainer"]["max_epochs"] = 20 + # initialize trainer, which will call background functionality trainer = pl.Trainer( logger=tb_logger, callbacks=callbacks, **config.get("pl_trainer", {}) ) while True: try: + # actual training of the model trainer.fit(model, dm) except RuntimeError as e: + # if batch_size is choosen to big, it will be reduced until it fits the GPU logging.error(f"Caught RuntimeError: {e}") if str(e).find("CUDA out of memory") != -1: if dm.hparams.batch_size > 4: @@ -655,6 +745,25 @@ def train( log_dir: str, hpopt_file: str, ): + """ + Main function called during training. Also used for trial pruning and sampling new parameters in optuna + + debug: bool, use a strongly reduced dataframe during training + training_gene_file: str, path to a pickle file specifying on which genes training should be executed + n_trials: int, number of trials to be performed by the given setting + trial_id: int, current trial in range n_trials + sample_file: str, path to a pickle file, which specifies which samples should be considered during training + phenotype: array of phenotypes, containing an array of paths, where the underlying data is stored: + 1. str containing the phenotype name + 2. annotated gene variants as zarr file + 3. covariates each sample as zarr file + 4. 
ground truth phenotypes as zarr file + config_file: str, path to a yaml file, which serves for configuration + log_dir: str, path to were logs are written + hpopt_file: str, path to where a .db file should be created in which the results of hyperparameter + optimization are stored + """ + if len(phenotype) == 0: raise ValueError("At least one --phenotype option must be specified") @@ -682,6 +791,7 @@ def train( samples = slice(None) data = dict() + # pack underlying data into a single dict that can be passed to downstream functions for pheno, input_tensor_file, covariates_file, y_file in phenotype: data[pheno] = dict() data[pheno]["input_tensor_zarr"] = zarr.open(input_tensor_file, mode="r") From 89abc6b203d9690f534c229cd354097050703452 Mon Sep 17 00:00:00 2001 From: Munzlinger Date: Wed, 29 Nov 2023 11:58:33 +0100 Subject: [PATCH 2/4] added documentation and docstrings --- deeprvat/deeprvat/models.py | 230 +++++++++++++++++++++++++++--------- deeprvat/deeprvat/train.py | 134 +++++++++++++++------ 2 files changed, 270 insertions(+), 94 deletions(-) diff --git a/deeprvat/deeprvat/models.py b/deeprvat/deeprvat/models.py index 19d3f73c..66f67f35 100644 --- a/deeprvat/deeprvat/models.py +++ b/deeprvat/deeprvat/models.py @@ -59,12 +59,16 @@ def __init__( **kwargs, ): """ - config: dict, representing the content of config.yaml - n_annotations: dict, containing the number of annotations used each phenotype - n_covariates: dict, containing the number of covariates used each phenotype - n_genes: dict, containing the number of genes used each phenotype - phenotypes: list, containing phenotypes used during training - stage: str, containing a prefix, pointing to the dataset the model is operating on + Initializes BaseModel. + + Args: + - config (dict): Represents the content of config.yaml. + - n_annotations (Dict[str, int]): Contains the number of annotations used for each phenotype. + - n_covariates (Dict[str, int]): Contains the number of covariates used for each phenotype. + - n_genes (Dict[str, int]): Contains the number of genes used for each phenotype. + - phenotypes (List[str]): Contains the phenotypes used during training. + - stage (str, optional): Contains a prefix indicating the dataset the model is operating on. Defaults to "train". + - **kwargs: Additional keyword arguments. """ super().__init__() self.save_hyperparameters(config) @@ -120,8 +124,19 @@ def configure_optimizers(self) -> torch.optim.Optimizer: def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: """ - Function is called by the trainer during training and is expected to return - a loss from which we can compute backward passes. + Function called by trainer during training and returns the loss used + to update weights and biases. + + Args: + - batch (dict): A dictionary containing the batch data. + - batch_idx (int): The index of the current batch. + + Returns: + - torch.Tensor: The loss value computed to update weights and biases + based on the predictions. + + Raises: + - RuntimeError: If NaNs are found in the training loss. """ # calls DeepSet.forward() y_pred_by_pheno = self(batch) @@ -148,6 +163,14 @@ def validation_step(self, batch: dict, batch_idx: int): """ During validation we do not compute backward passes, such that we can accumulate phenotype predictions and evaluate them afterwards as a whole. + + Args: + - batch (dict): A dictionary containing the validation batch data. + - batch_idx (int): The index of the current validation batch. 
+ + Returns: + - dict: A dictionary containing phenotype predictions ("y_pred_by_pheno") + and corresponding ground truth values ("y_by_pheno"). """ y_by_pheno = {pheno: pheno_batch["y"] for pheno, pheno_batch in batch.items()} return {"y_pred_by_pheno": self(batch), "y_by_pheno": y_by_pheno} @@ -155,10 +178,14 @@ def validation_step(self, batch: dict, batch_idx: int): def validation_epoch_end( self, prediction_y: List[Dict[str, Dict[str, torch.Tensor]]] - ): + ): + """ + Evaluate accumulated phenotype predictions at the end of the validation epoch. + + Args: + - prediction_y (List[Dict[str, Dict[str, torch.Tensor]]]): A list of dictionaries containing accumulated phenotype predictions + and corresponding ground truth values obtained during the validation process. """ - Evaluate all phenotype predictions in one go after accumulating them earlier - """ y_pred_by_pheno = dict() y_by_pheno = dict() for result in prediction_y: @@ -204,13 +231,25 @@ def test_step(self, batch: dict, batch_idx: int): """ During testing we do not compute backward passes, such that we can accumulate phenotype predictions and evaluate them afterwards as a whole. + + Args: + - batch (dict): A dictionary containing the validation batch data. + - batch_idx (int): The index of the current validation batch. + + Returns: + - dict: A dictionary containing phenotype predictions ("y_pred") + and corresponding ground truth values ("y"). """ return {"y_pred": self(batch), "y": batch["y"]} def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]): """ - Evaluate all phenotype predictions in one go after accumulating them earlier + Evaluate accumulated phenotype predictions at the end of the testing epoch. + + Args: + - prediction_y (List[Dict[str, Dict[str, torch.Tensor]]]): A list of dictionaries containing accumulated phenotype predictions + and corresponding ground truth values obtained during the testing process. """ y_pred = torch.cat([p["y_pred"] for p in prediction_y]) y = torch.cat([p["y"] for p in prediction_y]) @@ -250,19 +289,20 @@ def __init__( reverse: bool = False, ): """ - n_annotations: int, number of annotations - phi_layers: int, number of layers in Phi - phi_hidden_dim: int, internal dimensionality of linear layers in Phi - rho_layers: int, number of layers in Rho - rho_hidden_dim: int, internal dimensionality of linear layers in Rho - activation: str, activation function used, has to match its name in torch.nn - pool: str, invariant aggregation function used to aggregate gene variants - possiblities: max, sum - output_dim: int, number of burden scores - dropout: float, propability, by which some parameters are set to 0 - use_sigmoid: booleon, to project burden scores to [0, 1]. Also used as a - linear activation function to mimic association testing during training - reverse: booleon, to reverse the burden score. (used during association testing) + Initializes the DeepSetAgg module. + + Args: + - n_annotations (int): Number of annotations. + - phi_layers (int): Number of layers in Phi. + - phi_hidden_dim (int): Internal dimensionality of linear layers in Phi. + - rho_layers (int): Number of layers in Rho. + - rho_hidden_dim (int): Internal dimensionality of linear layers in Rho. + - activation (str): Activation function used; should match its name in torch.nn. + - pool (str): Invariant aggregation function used to aggregate gene variants. Possible values: 'max', 'sum'. + - output_dim (int, optional): Number of burden scores. Defaults to 1. 
+ - dropout (Optional[float], optional): Probability by which some parameters are set to 0. + - use_sigmoid (bool, optional): Whether to project burden scores to [0, 1]. Also used as a linear activation function during training. Defaults to False. + - reverse (bool, optional): Whether to reverse the burden score (used during association testing). Defaults to False. """ super().__init__() @@ -306,12 +346,29 @@ def __init__( # gene impairment scores and phenotypes linear self.rho = nn.Sequential(OrderedDict(rho)) - # reverse burden score during association testing if model predicts in negative space. - # compare associate.py, reverse_models() for further detail. + def set_reverse(self, reverse: bool = True): + """ + reverse burden score during association testing if model predicts in negative space. + + Args: + - reverse (bool, optional): Indicates whether the 'reverse' attribute should be set to True or False. + Defaults to True. + Note: + - Compare associate.py, reverse_models() for further detail. + """ self.reverse = reverse def forward(self, x): + """ + Perform forward pass through the model. + + Args: + - x (tensor): Batched input data + + Returns: + - tensor: Burden scores + """ x = self.phi(x.permute((0, 1, 3, 2))) # x.shape = samples x genes x variants x phi_latent if self.pool == "sum": @@ -320,7 +377,7 @@ def forward(self, x): x = torch.max(x, dim=2).values # Now x.shape = samples x genes x phi_latent x = self.rho(x) - # x.shape = samples x genes x rho_latent + # x.shape = samples x genes x 1 if self.reverse: x = -x if self.use_sigmoid: @@ -346,17 +403,22 @@ def __init__( reverse: bool = False, **kwargs, ): + """ - config: dict, representing the content of config.yaml - n_annotations: dict, containing the number of annotations used each phenotype - n_covariates: dict, containing the number of covariates used each phenotype - n_genes: dict, containing the number of genes used each phenotype - phenotypes: list, containing phenotypes used during training - agg_model: pl.LightningModule / nn.Module, model used for burden computation - if module isnt given, it will be initialized - use_sigmoid: booleon, to project burden scores to [0, 1]. Also used as a - linear activation function to mimic association testing during training - reverse: booleon, to reverse the burden score. (used during association testing) + Initialize the DeepSet model. + + Args: + - config (dict): Containing the content of config.yaml. + - n_annotations (Dict[str, int]): Contains the number of annotations used for each phenotype. + - n_covariates (Dict[str, int]): Contains the number of covariates used for each phenotype. + - n_genes (Dict[str, int]): Contains the number of genes used for each phenotype. + - phenotypes (List[str]): Contains the phenotypes used during training. + - agg_model (Optional[pl.LightningModule / nn.Module]): Model used for burden computation. If not provided, + it will be initialized. + - use_sigmoid (bool): Determines if burden scores should be projected to [0, 1]. Acts as a linear activation + function to mimic association testing during training. + - reverse (bool): Determines if the burden score should be reversed (used during association testing). + - **kwargs: Additional keyword arguments. 
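For illustration, an editor's sketch of what the internally constructed gene impairment module does to a batch. All dimensions below are invented example values (not defaults from the codebase), and DeepSetAgg is assumed to be the class defined earlier in this module:

    import torch

    agg = DeepSetAgg(
        n_annotations=34, phi_layers=2, phi_hidden_dim=20,
        rho_layers=3, rho_hidden_dim=10,
        activation="LeakyReLU", pool="sum",
    )
    x = torch.randn(16, 100, 34, 50)  # samples x genes x annotations x variants
    burdens = agg(x)                  # -> samples x genes x 1 (one burden score per gene)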
""" super().__init__( config, n_annotations, n_covariates, n_genes, phenotypes, **kwargs @@ -369,6 +431,9 @@ def __init__( pool = get_hparam(self, "pool", "sum") dropout = get_hparam(self, "dropout", None) + # self.agg_model compresses a batch + # from: samples x genes x annotations x variants + # to: samples x genes if agg_model is not None: self.agg_model = agg_model else: @@ -385,14 +450,12 @@ def __init__( reverse=reverse, ) self.agg_model.train(False if self.hparams.stage == "val" else True) - - # dict of various linear layers used for phenotype prediction. - # Returns can be tested against ground truth data. - # self.agg_model compresses a batch - # from: samples x genes x annotations x variants; - # to: samples x genes # afterwards genes are concatenated with covariates # to: samples x (genes + covariates) + + + # dict of various linear layers used for phenotype prediction. + # Returns can be tested against ground truth data. self.gene_pheno = nn.ModuleDict( { pheno: nn.Linear( @@ -404,19 +467,18 @@ def __init__( def forward(self, batch): """ - batch: dict of phenotypes, each containing the following keys: - - indices: tensor, - indices for the underlying dataframe - - - covariates: tensor, - covariates of samples e.g. age, - content: samples x covariates - - - rare_variant_annotations: tensor, - actual annotated genomic variants, - content: samples x genes x annotations x variants - - y: tensor, - actual phenotypes (ground truth data) + Forward pass through the model. + + Args: + - batch (dict): Dictionary of phenotypes, each containing the following keys: + - indices (tensor): Indices for the underlying dataframe. + - covariates (tensor): Covariates of samples, e.g., age. Content: samples x covariates. + - rare_variant_annotations (tensor): annotated genomic variants. + Content: samples x genes x annotations x variants. + - y (tensor): Actual phenotypes (ground truth data). + + Returns: + - dict: Dictionary containing predicted phenotypes """ result = dict() for pheno, this_batch in batch.items(): @@ -432,7 +494,21 @@ def forward(self, batch): class LinearAgg(pl.LightningModule): + """ + To capture only linear effect, this model can be used as it only uses a single + linear layer without a non-linear activation function. + It still contains the gene impairment module used for burden computation. + """ + def __init__(self, n_annotations: int, pool: str, output_dim: int = 1): + """ + Initialize the LinearAgg model. + + Args: + - n_annotations (int): Number of annotations. + - pool (str): Pooling method ("sum" or "max") to be used. + - output_dim (int, optional): Dimensionality of the output. Defaults to 1. + """ super().__init__() self.output_dim = output_dim @@ -442,6 +518,15 @@ def __init__(self, n_annotations: int, pool: str, output_dim: int = 1): self.linear = nn.Linear(n_annotations, self.output_dim) def forward(self, x): + """ + Perform forward pass through the model. + + Args: + - x (tensor): Batched input data + + Returns: + - tensor: Burden scores + """ x = self.linear( x.permute((0, 1, 3, 2)) ) # x.shape = samples x genes x variants x output_dim @@ -454,6 +539,11 @@ def forward(self, x): class TwoLayer(BaseModel): + """ + Wrapper class to capture linear effects. Inherits parameters from BaseModel, + which is where Pytorch Lightning specific functions like "training_step" or + "validation_epoch_end" can be found. Those functions are called in background by default. 
+ """ def __init__( self, config: dict, @@ -463,6 +553,18 @@ def __init__( agg_model: Optional[nn.Module] = None, **kwargs, ): + """ + Initializes the TwoLayer model. + + Args: + - config (dict): Represents the content of config.yaml. + - n_annotations (int): Number of annotations. + - n_covariates (int): Number of covariates. + - n_genes (int): Number of genes. + - agg_model (Optional[nn.Module]): Model used for burden computation. If not provided, + it will be initialized. + - **kwargs: Additional keyword arguments. + """ super().__init__(config, n_annotations, n_covariates, n_genes, **kwargs) logger.info("Initializing TwoLayer model with parameters:") @@ -488,6 +590,20 @@ def __init__( self.gene_pheno = nn.Linear(self.hparams.n_covariates + self.hparams.n_genes, 1) def forward(self, batch): + """ + Forward pass through the model. + + Args: + - batch (dict): Dictionary of phenotypes, each containing the following keys: + - indices (tensor): Indices for the underlying dataframe. + - covariates (tensor): Covariates of samples, e.g., age. Content: samples x covariates. + - rare_variant_annotations (tensor): annotated genomic variants. + Content: samples x genes x annotations x variants. + - y (tensor): Actual phenotypes (ground truth data). + + Returns: + - dict: Dictionary containing predicted phenotypes + """ # samples x genes x annotations x variants x = batch["rare_variant_annotations"] x = self.agg_model(x).squeeze(dim=2) # samples x genes diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index 5a06b924..4d465966 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -82,6 +82,19 @@ def make_dataset_( training_dataset_file: str = None, pickle_only: bool = False, ): + """ + Subfunction of make_dataset() + Convert a dataset file to the sparse format used for training and testing associations + + Args: + - config (Dict): Dictionary containing configuration parameters, build from YAML file + - debug (bool, optional): Use a strongly reduced dataframe + - training_dataset_file (str, optional): Path to the file in which training data is stored. + - pickle_only (bool, optional): If True, only store dataset as pickle file and return None. + + Returns: + - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing input_tensor, covariates, and target values. + """ n_phenotypes = config.get("n_phenotypes", None) if n_phenotypes is not None: if "seed_genes" in config: @@ -113,6 +126,7 @@ def make_dataset_( or training_dataset_file is None or not Path(training_dataset_file).is_file() ): + # load data into sparse data format ds = DenseGTDataset( gt_file=config["training_data"]["gt_file"], variant_file=config["training_data"]["variant_file"], @@ -190,6 +204,24 @@ def make_dataset( covariates_out_file: str, y_out_file: str, ): + """ + Uses function make_dataset_() to convert dataset to sparse format and stores the respective data + + Args: + - debug (bool): Use a strongly reduced dataframe + - pickle_only (bool): Flag to indicate whether only to save data using pickle + - compression_level (int): Level of compression in ZARR to be applied to training data. + - training_dataset_file (Optional[str]): Path to the file in which training data is stored. + - config_file (str): Path to a YAML file, which serves for configuration. + - input_tensor_out_file (str): Path to save the training data to. + - covariates_out_file (str): Path to save the covariates to. + - y_out_file (str): Path to save the ground truth data to. 
+ + Returns: + - None + """ + + with open(config_file) as f: config = yaml.safe_load(f) @@ -231,12 +263,16 @@ def __init__( # genes: Optional[Union[slice, np.ndarray]] = None ): """ - data: dict, underlying dataframe from which data is structured into batches - min_variant_count: int, minimum number of variants available each gene. - batch_size: int, number of samples / individuals available in one batch - split: str, containing a prefix, pointing to the dataset the model is operating on - cache_tensors: bool, indicates if samples have been pre-loaded or need to be extracted from zarr + Initialize the MultiphenoDataset. + + Args: + - data (Dict[str, Dict]): Underlying dataframe from which data is structured into batches. + - min_variant_count (int): Minimum number of variants available for each gene. + - batch_size (int): Number of samples/individuals available in one batch. + - split (str, optional): Contains a prefix indicating the dataset the model operates on. Defaults to "train". + - cache_tensors (bool, optional): Indicates if samples have been pre-loaded or need to be extracted from zarr. """ + super().__init__() self.data = data @@ -364,15 +400,17 @@ def __init__( cache_tensors: bool = False, ): """ - data: dict, underlying dataframe from which data is structured into batches - train_proportion: float, percentage by which data is devided into training / validation split - sample_with_replacement: bool, if true a sample of a can be selected multiple times in one epoch. - min_variant_count: int minimum number of variants available each gene. - upsampling_factor: int, percentual factor by which we want to upsample data; >= 1; - however, not yet implemented for multi-phenotype training! - batch_size: int, number of samples / individuals available in one batch - num_workers: int, number of workers which simultaneously putting data into RAM - cache_tensors: bool, indicates if samples have been pre-loaded or need to be extracted from zarr + Initialize the MultiphenoBaggingData. + + Args: + - data (Dict[str, Dict]): Underlying dataframe from which data structured into batches. + - train_proportion (float): Percentage by which data is divided into training/validation split. + - sample_with_replacement (bool, optional): If True, a sample can be selected multiple times in one epoch. Defaults to True. + - min_variant_count (int, optional): Minimum number of variants available for each gene. Defaults to 1. + - upsampling_factor (int, optional): Percentual factor by which to upsample data; >= 1. Defaults to 1. + - batch_size (Optional[int], optional): Number of samples/individuals available in one batch. Defaults to None. + - num_workers (Optional[int], optional): Number of workers simultaneously putting data into RAM. Defaults to 0. + - cache_tensors (bool, optional): Indicates if samples have been pre-loaded or need to be extracted from zarr. Defaults to False. """ logger.info("Intializing datamodule") @@ -522,15 +560,19 @@ def run_bagging( debug: bool = False, ) -> Optional[float]: """ - Main function called during training. 
Also used for trial pruning and sampling new parameters in optuna - - config: dict, build from yaml, which serves for configuration - data: dict of phenotypes, each containing a dict, storing the underlying data - log_dir: str, path to were logs are written - checkpoint_file: str, path to where the weights of the trained model should be saved - trial: optuna object, generated from the study - trial_id: int, current trial in range n_trials - debug: bool, use a strongly reduced dataframe during training + Main function called during training. Also used for trial pruning and sampling new parameters in optuna. + + Args: + - config (Dict): Dictionary containing configuration parameters, build from YAML file + - data (Dict[str, Dict]): Dict of phenotypes, each containing a dict storing the underlying data. + - log_dir (str): Path to where logs are written. + - checkpoint_file (Optional[str]): Path to where the weights of the trained model should be saved. + - trial (Optional[optuna.trial.Trial]): Optuna object generated from the study. + - trial_id (Optional[int]): Current trial in range n_trials. + - debug (bool): Use a strongly reduced dataframe + + Returns: + - Optional[float]: computes the lowest scores of all loss metrics and returns their average """ # if hyperparameter optimization is performed (train(); hpopt_file != None) @@ -743,22 +785,26 @@ def train( """ Main function called during training. Also used for trial pruning and sampling new parameters in optuna - debug: bool, use a strongly reduced dataframe during training - training_gene_file: str, path to a pickle file specifying on which genes training should be executed - n_trials: int, number of trials to be performed by the given setting - trial_id: int, current trial in range n_trials - sample_file: str, path to a pickle file, which specifies which samples should be considered during training - phenotype: array of phenotypes, containing an array of paths, where the underlying data is stored: - 1. str containing the phenotype name - 2. annotated gene variants as zarr file - 3. covariates each sample as zarr file - 4. ground truth phenotypes as zarr file - config_file: str, path to a yaml file, which serves for configuration - log_dir: str, path to were logs are written - hpopt_file: str, path to where a .db file should be created in which the results of hyperparameter - optimization are stored + Args: + - debug (bool): Use a strongly reduced dataframe + - training_gene_file (Optional[str]): Path to a pickle file specifying on which genes training should be executed. + - n_trials (int): Number of trials to be performed by the given setting. + - trial_id (Optional[int]): Current trial in range n_trials. + - sample_file (Optional[str]): Path to a pickle file specifying which samples should be considered during training. + - phenotype (Tuple[Tuple[str, str, str, str]]): Array of phenotypes, containing an array of paths where the underlying data is stored: + - str: Phenotype name + - str: Annotated gene variants as zarr file + - str: Covariates each sample as zarr file + - str: Ground truth phenotypes as zarr file + - config_file (str): Path to a YAML file, which serves for configuration. + - log_dir (str): Path to where logs are stored. + - hpopt_file (str): Path to where a .db file should be created in which the results of hyperparameter optimization are stored. + + Raises: + - ValueError: If no phenotype option is specified. 
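As a concrete illustration of the expected phenotype structure (the file paths below are invented; only the 4-tuple layout is taken from this docstring and from the loop that unpacks it further down):

    phenotype = (
        ("Urate",
         "Urate/deeprvat/input_tensor.zarr",  # annotated gene variants
         "Urate/deeprvat/covariates.zarr",    # covariates per sample
         "Urate/deeprvat/y.zarr"),            # ground-truth phenotype values
    )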
""" + if len(phenotype) == 0: raise ValueError("At least one --phenotype option must be specified") @@ -879,6 +925,19 @@ def train( def best_training_run( debug: bool, log_dir: str, checkpoint_dir: str, hpopt_db: str, config_file_out: str ): + """ + Function to extract the best trial from an Optuna study and handle associated model checkpoints and configurations. + + Args: + - debug (bool): Use a strongly reduced dataframe + - log_dir (str): Path to where logs are stored. + - checkpoint_dir (str): Directory where checkpoints have been stored. + - hpopt_db (str): Path to the database file containing the Optuna study results. + - config_file_out (str): store a reduced + + Returns: + - None + """ study = optuna.load_study( study_name=Path(hpopt_db).stem, storage=f"sqlite:///{hpopt_db}" ) @@ -902,6 +961,7 @@ def best_training_run( link_path.symlink_to(checkpoint.resolve(strict=True)) # Keep track of models marked to be dropped + # respective models are not used for downstream processing checkpoint_dropped = Path(str(checkpoint) + ".dropped") if checkpoint_dropped.is_file(): dropped_link_path = Path(checkpoint_dir) / f"bag_{k}.ckpt.dropped" From c412872a7a1db58e443f044c5d26f3de63586008 Mon Sep 17 00:00:00 2001 From: Magnus Wahlberg Date: Thu, 30 Nov 2023 10:03:11 +0100 Subject: [PATCH 3/4] Change format to sphinx --- deeprvat/deeprvat/models.py | 288 ++++++++++++++++++++---------------- deeprvat/deeprvat/train.py | 182 ++++++++++++++--------- 2 files changed, 267 insertions(+), 203 deletions(-) diff --git a/deeprvat/deeprvat/models.py b/deeprvat/deeprvat/models.py index 66f67f35..7f1739bb 100644 --- a/deeprvat/deeprvat/models.py +++ b/deeprvat/deeprvat/models.py @@ -41,13 +41,13 @@ def get_hparam(module: pl.LightningModule, param: str, default: Any): else: return default + class BaseModel(pl.LightningModule): """ Base class containing functions that will be called by PyTorch Lightning in the - background by default. + background by default. """ - def __init__( self, config: dict, @@ -61,14 +61,19 @@ def __init__( """ Initializes BaseModel. - Args: - - config (dict): Represents the content of config.yaml. - - n_annotations (Dict[str, int]): Contains the number of annotations used for each phenotype. - - n_covariates (Dict[str, int]): Contains the number of covariates used for each phenotype. - - n_genes (Dict[str, int]): Contains the number of genes used for each phenotype. - - phenotypes (List[str]): Contains the phenotypes used during training. - - stage (str, optional): Contains a prefix indicating the dataset the model is operating on. Defaults to "train". - - **kwargs: Additional keyword arguments. + :param config: Represents the content of config.yaml. + :type config: dict + :param n_annotations: Contains the number of annotations used for each phenotype. + :type n_annotations: Dict[str, int] + :param n_covariates: Contains the number of covariates used for each phenotype. + :type n_covariates: Dict[str, int] + :param n_genes: Contains the number of genes used for each phenotype. + :type n_genes: Dict[str, int] + :param phenotypes: Contains the phenotypes used during training. + :type phenotypes: List[str] + :param stage: Contains a prefix indicating the dataset the model is operating on. Defaults to "train". (optional) + :type stage: str + :param kwargs: Additional keyword arguments. 
""" super().__init__() self.save_hyperparameters(config) @@ -91,7 +96,6 @@ def __init__( else: raise ValueError("Unknown objective_mode configuration parameter") - def configure_optimizers(self) -> torch.optim.Optimizer: """ Function used to setup an optimizer and scheduler by their @@ -121,22 +125,19 @@ def configure_optimizers(self) -> torch.optim.Optimizer: else: return optimizer - def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: """ - Function called by trainer during training and returns the loss used + Function called by trainer during training and returns the loss used to update weights and biases. - Args: - - batch (dict): A dictionary containing the batch data. - - batch_idx (int): The index of the current batch. - - Returns: - - torch.Tensor: The loss value computed to update weights and biases - based on the predictions. + :param batch: A dictionary containing the batch data. + :type batch: dict + :param batch_idx: The index of the current batch. + :type batch_idx: int - Raises: - - RuntimeError: If NaNs are found in the training loss. + :returns: torch.Tensor: The loss value computed to update weights and biases + based on the predictions. + :raises RuntimeError: If NaNs are found in the training loss. """ # calls DeepSet.forward() y_pred_by_pheno = self(batch) @@ -161,30 +162,36 @@ def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: def validation_step(self, batch: dict, batch_idx: int): """ - During validation we do not compute backward passes, such that we can accumulate - phenotype predictions and evaluate them afterwards as a whole. + During validation, we do not compute backward passes, such that we can accumulate + phenotype predictions and evaluate them afterward as a whole. - Args: - - batch (dict): A dictionary containing the validation batch data. - - batch_idx (int): The index of the current validation batch. + :param batch: A dictionary containing the validation batch data. + :type batch: dict + :param batch_idx: The index of the current validation batch. + :type batch_idx: int - Returns: - - dict: A dictionary containing phenotype predictions ("y_pred_by_pheno") - and corresponding ground truth values ("y_by_pheno"). + :returns: dict: A dictionary containing phenotype predictions ("y_pred_by_pheno") + and corresponding ground truth values ("y_by_pheno"). """ y_by_pheno = {pheno: pheno_batch["y"] for pheno, pheno_batch in batch.items()} return {"y_pred_by_pheno": self(batch), "y_by_pheno": y_by_pheno} - def validation_epoch_end( self, prediction_y: List[Dict[str, Dict[str, torch.Tensor]]] - ): + ): """ Evaluate accumulated phenotype predictions at the end of the validation epoch. - Args: - - prediction_y (List[Dict[str, Dict[str, torch.Tensor]]]): A list of dictionaries containing accumulated phenotype predictions - and corresponding ground truth values obtained during the validation process. + This function takes a list of dictionaries containing accumulated phenotype predictions + and corresponding ground truth values obtained during the validation process. It computes + various metrics based on these predictions and logs the results. + + :param prediction_y: A list of dictionaries containing accumulated phenotype predictions + and corresponding ground truth values obtained during the validation process. 
+ :type prediction_y: List[Dict[str, Dict[str, torch.Tensor]]] + + :return: None + :rtype: None """ y_pred_by_pheno = dict() y_by_pheno = dict() @@ -226,30 +233,29 @@ def validation_epoch_end( self.best_objective, results[self.hparams.metrics["objective"]].item() ) - def test_step(self, batch: dict, batch_idx: int): """ - During testing we do not compute backward passes, such that we can accumulate - phenotype predictions and evaluate them afterwards as a whole. + During testing, we do not compute backward passes, such that we can accumulate + phenotype predictions and evaluate them afterward as a whole. - Args: - - batch (dict): A dictionary containing the validation batch data. - - batch_idx (int): The index of the current validation batch. + :param batch: A dictionary containing the testing batch data. + :type batch: dict + :param batch_idx: The index of the current testing batch. + :type batch_idx: int - Returns: - - dict: A dictionary containing phenotype predictions ("y_pred") - and corresponding ground truth values ("y"). + :returns: dict: A dictionary containing phenotype predictions ("y_pred") + and corresponding ground truth values ("y"). + :rtype: dict """ return {"y_pred": self(batch), "y": batch["y"]} - def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]): """ Evaluate accumulated phenotype predictions at the end of the testing epoch. - Args: - - prediction_y (List[Dict[str, Dict[str, torch.Tensor]]]): A list of dictionaries containing accumulated phenotype predictions - and corresponding ground truth values obtained during the testing process. + :param prediction_y: A list of dictionaries containing accumulated phenotype predictions + and corresponding ground truth values obtained during the testing process. + :type prediction_y: List[Dict[str, Dict[str, torch.Tensor]]] """ y_pred = torch.cat([p["y_pred"] for p in prediction_y]) y = torch.cat([p["y"] for p in prediction_y]) @@ -266,14 +272,17 @@ def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]): def configure_callbacks(self): return [ModelSummary()] + class DeepSetAgg(pl.LightningModule): """ - class contains the gene impairment module used for burden computation. - Variants are fed through an embedding network Phi, to compute a variant embedding - The variant embedding is processed by a permutation-invariant aggregation to yield a gene embedding. - Afterwards second network Rho, estimates the final gene impairment score. + class contains the gene impairment module used for burden computation. + + Variants are fed through an embedding network Phi to compute a variant embedding. + The variant embedding is processed by a permutation-invariant aggregation to yield a gene embedding. + Afterward, the second network Rho estimates the final gene impairment score. All parameters of the gene impairment module are shared across genes and traits. """ + def __init__( self, n_annotations: int, @@ -291,18 +300,28 @@ def __init__( """ Initializes the DeepSetAgg module. - Args: - - n_annotations (int): Number of annotations. - - phi_layers (int): Number of layers in Phi. - - phi_hidden_dim (int): Internal dimensionality of linear layers in Phi. - - rho_layers (int): Number of layers in Rho. - - rho_hidden_dim (int): Internal dimensionality of linear layers in Rho. - - activation (str): Activation function used; should match its name in torch.nn. - - pool (str): Invariant aggregation function used to aggregate gene variants. Possible values: 'max', 'sum'. 
- - output_dim (int, optional): Number of burden scores. Defaults to 1. - - dropout (Optional[float], optional): Probability by which some parameters are set to 0. - - use_sigmoid (bool, optional): Whether to project burden scores to [0, 1]. Also used as a linear activation function during training. Defaults to False. - - reverse (bool, optional): Whether to reverse the burden score (used during association testing). Defaults to False. + :param n_annotations: Number of annotations. + :type n_annotations: int + :param phi_layers: Number of layers in Phi. + :type phi_layers: int + :param phi_hidden_dim: Internal dimensionality of linear layers in Phi. + :type phi_hidden_dim: int + :param rho_layers: Number of layers in Rho. + :type rho_layers: int + :param rho_hidden_dim: Internal dimensionality of linear layers in Rho. + :type rho_hidden_dim: int + :param activation: Activation function used; should match its name in torch.nn. + :type activation: str + :param pool: Invariant aggregation function used to aggregate gene variants. Possible values: 'max', 'sum'. + :type pool: str + :param output_dim: Number of burden scores. Defaults to 1. (optional) + :type output_dim: int + :param dropout: Probability by which some parameters are set to 0. (optional) + :type dropout: Optional[float] + :param use_sigmoid: Whether to project burden scores to [0, 1]. Also used as a linear activation function during training. Defaults to False. (optional) + :type use_sigmoid: bool + :param reverse: Whether to reverse the burden score (used during association testing). Defaults to False. (optional) + :type reverse: bool """ super().__init__() @@ -342,32 +361,32 @@ def __init__( rho.append( (f"rho_linear_{rho_layers - 1}", nn.Linear(input_dim, self.output_dim)) ) - # No final non-linear activation function to keep the relationship between + # No final non-linear activation function to keep the relationship between # gene impairment scores and phenotypes linear self.rho = nn.Sequential(OrderedDict(rho)) - def set_reverse(self, reverse: bool = True): """ - reverse burden score during association testing if model predicts in negative space. + Reverse burden score during association testing if the model predicts in negative space. + + :param reverse: Indicates whether the 'reverse' attribute should be set to True or False. + Defaults to True. + :type reverse: bool - Args: - - reverse (bool, optional): Indicates whether the 'reverse' attribute should be set to True or False. - Defaults to True. Note: - - Compare associate.py, reverse_models() for further detail. + Compare associate.py, reverse_models() for further detail """ self.reverse = reverse def forward(self, x): """ - Perform forward pass through the model. + Perform a forward pass through the model. - Args: - - x (tensor): Batched input data + :param x: Batched input data + :type x: tensor - Returns: - - tensor: Burden scores + :returns: Burden scores + :rtype: tensor """ x = self.phi(x.permute((0, 1, 3, 2))) # x.shape = samples x genes x variants x phi_latent @@ -384,13 +403,15 @@ def forward(self, x): x = torch.sigmoid(x) return x + class DeepSet(BaseModel): """ - Wrapper class for burden computation, that also does phenotype prediction. - It inherits parameters from BaseModel, which is where Pytorch Lightning specific functions + Wrapper class for burden computation, that also does phenotype prediction. + It inherits parameters from BaseModel, which is where Pytorch Lightning specific functions like "training_step" or "validation_epoch_end" can be found. 
Those functions are called in background by default. """ + def __init__( self, config: dict, @@ -403,22 +424,27 @@ def __init__( reverse: bool = False, **kwargs, ): - """ Initialize the DeepSet model. - Args: - - config (dict): Containing the content of config.yaml. - - n_annotations (Dict[str, int]): Contains the number of annotations used for each phenotype. - - n_covariates (Dict[str, int]): Contains the number of covariates used for each phenotype. - - n_genes (Dict[str, int]): Contains the number of genes used for each phenotype. - - phenotypes (List[str]): Contains the phenotypes used during training. - - agg_model (Optional[pl.LightningModule / nn.Module]): Model used for burden computation. If not provided, - it will be initialized. - - use_sigmoid (bool): Determines if burden scores should be projected to [0, 1]. Acts as a linear activation - function to mimic association testing during training. - - reverse (bool): Determines if the burden score should be reversed (used during association testing). - - **kwargs: Additional keyword arguments. + :param config: Containing the content of config.yaml. + :type config: dict + :param n_annotations: Contains the number of annotations used for each phenotype. + :type n_annotations: Dict[str, int] + :param n_covariates: Contains the number of covariates used for each phenotype. + :type n_covariates: Dict[str, int] + :param n_genes: Contains the number of genes used for each phenotype. + :type n_genes: Dict[str, int] + :param phenotypes: Contains the phenotypes used during training. + :type phenotypes: List[str] + :param agg_model: Model used for burden computation. If not provided, it will be initialized. (optional) + :type agg_model: Optional[pl.LightningModule / nn.Module] + :param use_sigmoid: Determines if burden scores should be projected to [0, 1]. Acts as a linear activation + function to mimic association testing during training. + :type use_sigmoid: bool + :param reverse: Determines if the burden score should be reversed (used during association testing). + :type reverse: bool + :param kwargs: Additional keyword arguments. """ super().__init__( config, n_annotations, n_covariates, n_genes, phenotypes, **kwargs @@ -431,7 +457,7 @@ def __init__( pool = get_hparam(self, "pool", "sum") dropout = get_hparam(self, "dropout", None) - # self.agg_model compresses a batch + # self.agg_model compresses a batch # from: samples x genes x annotations x variants # to: samples x genes if agg_model is not None: @@ -450,10 +476,9 @@ def __init__( reverse=reverse, ) self.agg_model.train(False if self.hparams.stage == "val" else True) - # afterwards genes are concatenated with covariates + # afterwards genes are concatenated with covariates # to: samples x (genes + covariates) - - + # dict of various linear layers used for phenotype prediction. # Returns can be tested against ground truth data. self.gene_pheno = nn.ModuleDict( @@ -464,21 +489,20 @@ def __init__( for pheno in self.hparams.phenotypes } ) - + def forward(self, batch): """ Forward pass through the model. - Args: - - batch (dict): Dictionary of phenotypes, each containing the following keys: + :param batch: Dictionary of phenotypes, each containing the following keys: - indices (tensor): Indices for the underlying dataframe. - covariates (tensor): Covariates of samples, e.g., age. Content: samples x covariates. - - rare_variant_annotations (tensor): annotated genomic variants. - Content: samples x genes x annotations x variants. 
+ - rare_variant_annotations (tensor): Annotated genomic variants. Content: samples x genes x annotations x variants. - y (tensor): Actual phenotypes (ground truth data). + :type batch: dict - Returns: - - dict: Dictionary containing predicted phenotypes + :returns: Dictionary containing predicted phenotypes + :rtype: dict """ result = dict() for pheno, this_batch in batch.items(): @@ -496,18 +520,20 @@ def forward(self, batch): class LinearAgg(pl.LightningModule): """ To capture only linear effect, this model can be used as it only uses a single - linear layer without a non-linear activation function. - It still contains the gene impairment module used for burden computation. + linear layer without a non-linear activation function. + It still contains the gene impairment module used for burden computation. """ def __init__(self, n_annotations: int, pool: str, output_dim: int = 1): """ Initialize the LinearAgg model. - Args: - - n_annotations (int): Number of annotations. - - pool (str): Pooling method ("sum" or "max") to be used. - - output_dim (int, optional): Dimensionality of the output. Defaults to 1. + :param n_annotations: Number of annotations. + :type n_annotations: int + :param pool: Pooling method ("sum" or "max") to be used. + :type pool: str + :param output_dim: Dimensionality of the output. Defaults to 1. (optional) + :type output_dim: int """ super().__init__() @@ -519,13 +545,13 @@ def __init__(self, n_annotations: int, pool: str, output_dim: int = 1): def forward(self, x): """ - Perform forward pass through the model. + Perform a forward pass through the model. - Args: - - x (tensor): Batched input data + :param x: Batched input data + :type x: tensor - Returns: - - tensor: Burden scores + :returns: Burden scores + :rtype: tensor """ x = self.linear( x.permute((0, 1, 3, 2)) @@ -540,10 +566,11 @@ def forward(self, x): class TwoLayer(BaseModel): """ - Wrapper class to capture linear effects. Inherits parameters from BaseModel, - which is where Pytorch Lightning specific functions like "training_step" or + Wrapper class to capture linear effects. Inherits parameters from BaseModel, + which is where Pytorch Lightning specific functions like "training_step" or "validation_epoch_end" can be found. Those functions are called in background by default. """ + def __init__( self, config: dict, @@ -556,14 +583,17 @@ def __init__( """ Initializes the TwoLayer model. - Args: - - config (dict): Represents the content of config.yaml. - - n_annotations (int): Number of annotations. - - n_covariates (int): Number of covariates. - - n_genes (int): Number of genes. - - agg_model (Optional[nn.Module]): Model used for burden computation. If not provided, - it will be initialized. - - **kwargs: Additional keyword arguments. + :param config: Represents the content of config.yaml. + :type config: dict + :param n_annotations: Number of annotations. + :type n_annotations: int + :param n_covariates: Number of covariates. + :type n_covariates: int + :param n_genes: Number of genes. + :type n_genes: int + :param agg_model: Model used for burden computation. If not provided, it will be initialized. (optional) + :type agg_model: Optional[nn.Module] + :param kwargs: Additional keyword arguments. """ super().__init__(config, n_annotations, n_covariates, n_genes, **kwargs) @@ -593,16 +623,15 @@ def forward(self, batch): """ Forward pass through the model. 
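# Illustrative sketch, not part of this patch: how burden scores produced by the
# aggregation model are concatenated with covariates and fed through one linear
# layer per phenotype, mirroring the "gene_pheno" ModuleDict described in the
# docstrings above. The phenotype names, sample counts, and dimensions are
# invented for the example.
import torch
import torch.nn as nn

n_samples, n_genes, n_covariates = 8, 5, 3
burdens = torch.rand(n_samples, n_genes)          # stand-in for the aggregation model output
covariates = torch.randn(n_samples, n_covariates)

gene_pheno = nn.ModuleDict({
    "Triglycerides": nn.Linear(n_covariates + n_genes, 1),
    "LDL_direct": nn.Linear(n_covariates + n_genes, 1),
})

x = torch.cat((covariates, burdens), dim=1)       # samples x (covariates + genes)
y_pred_by_pheno = {pheno: layer(x).squeeze(dim=1) for pheno, layer in gene_pheno.items()}
print({pheno: y.shape for pheno, y in y_pred_by_pheno.items()})  # each torch.Size([8])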
- Args: - - batch (dict): Dictionary of phenotypes, each containing the following keys: + :param batch: Dictionary of phenotypes, each containing the following keys: - indices (tensor): Indices for the underlying dataframe. - covariates (tensor): Covariates of samples, e.g., age. Content: samples x covariates. - - rare_variant_annotations (tensor): annotated genomic variants. - Content: samples x genes x annotations x variants. + - rare_variant_annotations (tensor): Annotated genomic variants. Content: samples x genes x annotations x variants. - y (tensor): Actual phenotypes (ground truth data). + :type batch: dict - Returns: - - dict: Dictionary containing predicted phenotypes + :returns: Dictionary containing predicted phenotypes + :rtype: dict """ # samples x genes x annotations x variants x = batch["rare_variant_annotations"] @@ -610,4 +639,3 @@ def forward(self, batch): x = torch.cat((batch["covariates"], x), dim=1) x = self.gene_pheno(x).squeeze(dim=1) # samples return x - diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index 4d465966..367e4576 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -86,14 +86,17 @@ def make_dataset_( Subfunction of make_dataset() Convert a dataset file to the sparse format used for training and testing associations - Args: - - config (Dict): Dictionary containing configuration parameters, build from YAML file - - debug (bool, optional): Use a strongly reduced dataframe - - training_dataset_file (str, optional): Path to the file in which training data is stored. - - pickle_only (bool, optional): If True, only store dataset as pickle file and return None. - - Returns: - - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing input_tensor, covariates, and target values. + :param config: Dictionary containing configuration parameters, build from YAML file + :type config: Dict + :param debug: Use a strongly reduced dataframe (optional) + :type debug: bool + :param training_dataset_file: Path to the file in which training data is stored. (optional) + :type training_dataset_file: str + :param pickle_only: If True, only store dataset as pickle file and return None. (optional) + :type pickle_only: bool + + :returns: Tuple containing input_tensor, covariates, and target values. + :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] """ n_phenotypes = config.get("n_phenotypes", None) if n_phenotypes is not None: @@ -207,18 +210,24 @@ def make_dataset( """ Uses function make_dataset_() to convert dataset to sparse format and stores the respective data - Args: - - debug (bool): Use a strongly reduced dataframe - - pickle_only (bool): Flag to indicate whether only to save data using pickle - - compression_level (int): Level of compression in ZARR to be applied to training data. - - training_dataset_file (Optional[str]): Path to the file in which training data is stored. - - config_file (str): Path to a YAML file, which serves for configuration. - - input_tensor_out_file (str): Path to save the training data to. - - covariates_out_file (str): Path to save the covariates to. - - y_out_file (str): Path to save the ground truth data to. - - Returns: - - None + :param debug: Use a strongly reduced dataframe + :type debug: bool + :param pickle_only: Flag to indicate whether only to save data using pickle + :type pickle_only: bool + :param compression_level: Level of compression in ZARR to be applied to training data. 
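# Illustrative sketch, not part of this patch: writing a training tensor to ZARR
# with a chosen compression level, as the "compression_level" parameter above
# suggests. The path, array shape, and compressor settings are invented, and the
# snippet assumes zarr 2.x with a numcodecs Blosc compressor.
import numpy as np
import zarr
from numcodecs import Blosc

input_tensor = np.random.rand(100, 5, 34, 7).astype("float32")  # samples x genes x annotations x variants
compressor = Blosc(cname="zstd", clevel=1, shuffle=Blosc.SHUFFLE)

z = zarr.open(
    "input_tensor.zarr",
    mode="w",
    shape=input_tensor.shape,
    dtype="float32",
    compressor=compressor,
)
z[:] = input_tensor

reloaded = zarr.open("input_tensor.zarr", mode="r")[:]
assert reloaded.shape == input_tensor.shape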
+ :type compression_level: int + :param training_dataset_file: Path to the file in which training data is stored. (optional) + :type training_dataset_file: Optional[str] + :param config_file: Path to a YAML file, which serves for configuration. + :type config_file: str + :param input_tensor_out_file: Path to save the training data to. + :type input_tensor_out_file: str + :param covariates_out_file: Path to save the covariates to. + :type covariates_out_file: str + :param y_out_file: Path to save the ground truth data to. + :type y_out_file: str + + :returns: None """ @@ -265,12 +274,16 @@ def __init__( """ Initialize the MultiphenoDataset. - Args: - - data (Dict[str, Dict]): Underlying dataframe from which data is structured into batches. - - min_variant_count (int): Minimum number of variants available for each gene. - - batch_size (int): Number of samples/individuals available in one batch. - - split (str, optional): Contains a prefix indicating the dataset the model operates on. Defaults to "train". - - cache_tensors (bool, optional): Indicates if samples have been pre-loaded or need to be extracted from zarr. + :param data: Underlying dataframe from which data is structured into batches. + :type data: Dict[str, Dict] + :param min_variant_count: Minimum number of variants available for each gene. + :type min_variant_count: int + :param batch_size: Number of samples/individuals available in one batch. + :type batch_size: int + :param split: Contains a prefix indicating the dataset the model operates on. Defaults to "train". (optional) + :type split: str + :param cache_tensors: Indicates if samples have been pre-loaded or need to be extracted from zarr. (optional) + :type cache_tensors: bool """ super().__init__() @@ -360,7 +373,7 @@ def __getitem__(self, index): def subset_samples(self): """ Function used to sort out samples which contain real phenotypes with NaN values and - samples with less variants each gene then what we specify as a minimum. + samples with fewer variants for each gene than the specified minimum. """ for pheno, pheno_data in self.data.items(): # First sum over annotations (dim 2) for each variant in each gene. @@ -402,15 +415,22 @@ def __init__( """ Initialize the MultiphenoBaggingData. - Args: - - data (Dict[str, Dict]): Underlying dataframe from which data structured into batches. - - train_proportion (float): Percentage by which data is divided into training/validation split. - - sample_with_replacement (bool, optional): If True, a sample can be selected multiple times in one epoch. Defaults to True. - - min_variant_count (int, optional): Minimum number of variants available for each gene. Defaults to 1. - - upsampling_factor (int, optional): Percentual factor by which to upsample data; >= 1. Defaults to 1. - - batch_size (Optional[int], optional): Number of samples/individuals available in one batch. Defaults to None. - - num_workers (Optional[int], optional): Number of workers simultaneously putting data into RAM. Defaults to 0. - - cache_tensors (bool, optional): Indicates if samples have been pre-loaded or need to be extracted from zarr. Defaults to False. + :param data: Underlying dataframe from which data structured into batches. + :type data: Dict[str, Dict] + :param train_proportion: Percentage by which data is divided into training/validation split. + :type train_proportion: float + :param sample_with_replacement: If True, a sample can be selected multiple times in one epoch. Defaults to True. 
(optional) + :type sample_with_replacement: bool + :param min_variant_count: Minimum number of variants available for each gene. Defaults to 1. (optional) + :type min_variant_count: int + :param upsampling_factor: Percentual factor by which to upsample data; >= 1. Defaults to 1. (optional) + :type upsampling_factor: int + :param batch_size: Number of samples/individuals available in one batch. Defaults to None. (optional) + :type batch_size: Optional[int] + :param num_workers: Number of workers simultaneously putting data into RAM. Defaults to 0. (optional) + :type num_workers: Optional[int] + :param cache_tensors: Indicates if samples have been pre-loaded or need to be extracted from zarr. Defaults to False. (optional) + :type cache_tensors: bool """ logger.info("Intializing datamodule") @@ -562,17 +582,23 @@ def run_bagging( """ Main function called during training. Also used for trial pruning and sampling new parameters in optuna. - Args: - - config (Dict): Dictionary containing configuration parameters, build from YAML file - - data (Dict[str, Dict]): Dict of phenotypes, each containing a dict storing the underlying data. - - log_dir (str): Path to where logs are written. - - checkpoint_file (Optional[str]): Path to where the weights of the trained model should be saved. - - trial (Optional[optuna.trial.Trial]): Optuna object generated from the study. - - trial_id (Optional[int]): Current trial in range n_trials. - - debug (bool): Use a strongly reduced dataframe - - Returns: - - Optional[float]: computes the lowest scores of all loss metrics and returns their average + :param config: Dictionary containing configuration parameters, build from YAML file + :type config: Dict + :param data: Dict of phenotypes, each containing a dict storing the underlying data. + :type data: Dict[str, Dict] + :param log_dir: Path to where logs are written. + :type log_dir: str + :param checkpoint_file: Path to where the weights of the trained model should be saved. (optional) + :type checkpoint_file: Optional[str] + :param trial: Optuna object generated from the study. (optional) + :type trial: Optional[optuna.trial.Trial] + :param trial_id: Current trial in range n_trials. (optional) + :type trial_id: Optional[int] + :param debug: Use a strongly reduced dataframe + :type debug: bool + + :returns: Optional[float]: computes the lowest scores of all loss metrics and returns their average + :rtype: Optional[float] """ # if hyperparameter optimization is performed (train(); hpopt_file != None) @@ -586,7 +612,7 @@ def run_bagging( # # phi_hidden_dim: 20 # hparam: - # type: int + # type : int # args: # - 16 # - 64 @@ -595,7 +621,7 @@ def run_bagging( # # this line should be translated into: # phi_layers = optuna.suggest_int(name="phi_hidden_dim", low=16, high=64, step=16) - # and afterwards replace the respective area in config to set the suggestion. + # and afterward replace the respective area in config to set the suggestion. config["model"]["config"] = suggest_hparams(config["model"]["config"], trial) logger.info("Model hyperparameters this trial:") pprint(config["model"]["config"]) @@ -783,25 +809,32 @@ def train( hpopt_file: str, ): """ - Main function called during training. Also used for trial pruning and sampling new parameters in optuna - - Args: - - debug (bool): Use a strongly reduced dataframe - - training_gene_file (Optional[str]): Path to a pickle file specifying on which genes training should be executed. - - n_trials (int): Number of trials to be performed by the given setting. 
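# Illustrative sketch, not part of this patch: turning a "hparam" block from the
# YAML config into an Optuna suggestion, as the in-code comment above
# ("phi_layers = optuna.suggest_int(...)") hints at. The dict stands in for the
# parsed YAML, and the dummy objective exists only to make the sketch runnable;
# suggest_hparams() in the real code is more general than this.
import optuna

hparam_spec = {"type": "int", "args": [16, 64], "kwargs": {"step": 16}}

def objective(trial: optuna.trial.Trial) -> float:
    suggest = getattr(trial, f"suggest_{hparam_spec['type']}")
    phi_hidden_dim = suggest("phi_hidden_dim", *hparam_spec["args"], **hparam_spec["kwargs"])
    return float((phi_hidden_dim - 32) ** 2)  # dummy objective, just for the example

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=5)
print(study.best_params)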
- - trial_id (Optional[int]): Current trial in range n_trials. - - sample_file (Optional[str]): Path to a pickle file specifying which samples should be considered during training. - - phenotype (Tuple[Tuple[str, str, str, str]]): Array of phenotypes, containing an array of paths where the underlying data is stored: + Main function called during training. Also used for trial pruning and sampling new parameters in Optuna. + + :param debug: Use a strongly reduced dataframe + :type debug: bool + :param training_gene_file: Path to a pickle file specifying on which genes training should be executed. (optional) + :type training_gene_file: Optional[str] + :param n_trials: Number of trials to be performed by the given setting. + :type n_trials: int + :param trial_id: Current trial in range n_trials. (optional) + :type trial_id: Optional[int] + :param sample_file: Path to a pickle file specifying which samples should be considered during training. (optional) + :type sample_file: Optional[str] + :param phenotype: Array of phenotypes, containing an array of paths where the underlying data is stored: - str: Phenotype name - str: Annotated gene variants as zarr file - str: Covariates each sample as zarr file - str: Ground truth phenotypes as zarr file - - config_file (str): Path to a YAML file, which serves for configuration. - - log_dir (str): Path to where logs are stored. - - hpopt_file (str): Path to where a .db file should be created in which the results of hyperparameter optimization are stored. - - Raises: - - ValueError: If no phenotype option is specified. + :type phenotype: Tuple[Tuple[str, str, str, str]] + :param config_file: Path to a YAML file, which serves for configuration. + :type config_file: str + :param log_dir: Path to where logs are stored. + :type log_dir: str + :param hpopt_file: Path to where a .db file should be created in which the results of hyperparameter optimization are stored. + :type hpopt_file: str + + :raises ValueError: If no phenotype option is specified. """ @@ -928,15 +961,18 @@ def best_training_run( """ Function to extract the best trial from an Optuna study and handle associated model checkpoints and configurations. - Args: - - debug (bool): Use a strongly reduced dataframe - - log_dir (str): Path to where logs are stored. - - checkpoint_dir (str): Directory where checkpoints have been stored. - - hpopt_db (str): Path to the database file containing the Optuna study results. - - config_file_out (str): store a reduced - - Returns: - - None + :param debug: Use a strongly reduced dataframe + :type debug: bool + :param log_dir: Path to where logs are stored. + :type log_dir: str + :param checkpoint_dir: Directory where checkpoints have been stored. + :type checkpoint_dir: str + :param hpopt_db: Path to the database file containing the Optuna study results. + :type hpopt_db: str + :param config_file_out: Path to store a reduced configuration file. 
+ :type config_file_out: str + + :returns: None """ study = optuna.load_study( study_name=Path(hpopt_db).stem, storage=f"sqlite:///{hpopt_db}" From 9d15726c1e53362720c9b686822dcccbabdac2f7 Mon Sep 17 00:00:00 2001 From: Magnus Wahlberg Date: Thu, 30 Nov 2023 10:04:55 +0100 Subject: [PATCH 4/4] Black code --- deeprvat/deeprvat/train.py | 39 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index 367e4576..3b414b9b 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -229,8 +229,7 @@ def make_dataset( :returns: None """ - - + with open(config_file) as f: config = yaml.safe_load(f) @@ -258,6 +257,7 @@ class MultiphenoDataset(Dataset): class used to structure the data and present a __getitem__ function to the dataloader, that will be used to load batches into the model """ + def __init__( self, # input_tensor: zarr.core.Array, @@ -317,7 +317,7 @@ def __init__( self.total_samples = sum([s.shape[0] for s in self.samples.values()]) self.batch_size = batch_size - # index all samples and categorize them by phenotype, such that we + # index all samples and categorize them by phenotype, such that we # get a dataframe repreenting a chain of phenotypes self.sample_order = pd.DataFrame( { @@ -368,7 +368,6 @@ def __getitem__(self, index): } return result - def subset_samples(self): """ @@ -401,6 +400,7 @@ class MultiphenoBaggingData(pl.LightningDataModule): """ Preprocess the underlying dataframe, to then load it into a dataset object """ + def __init__( self, data: Dict[str, Dict], @@ -484,7 +484,7 @@ def __init__( ) ) # samples which are not part of train_samples, but in samples - # are validation samples. + # are validation samples. pheno_data["samples"] = { "train": train_samples, "val": np.setdiff1d(samples, train_samples), @@ -497,10 +497,10 @@ def __init__( "num_workers", "cache_tensors", ) - + def upsample(self) -> np.ndarray: """ - does not work at the moment for multi-phenotype training. Needs some minor changes + does not work at the moment for multi-phenotype training. Needs some minor changes to make it work again """ unique_values = self.y.unique() @@ -525,11 +525,10 @@ def upsample(self) -> np.ndarray: self.samples = upsampled_indices - def train_dataloader(self): """ - trainning samples have been selected, but to structure them and make them load - as a batch they are packed in a dataset class, which is then wrapped by a + trainning samples have been selected, but to structure them and make them load + as a batch they are packed in a dataset class, which is then wrapped by a dataloading object. """ logger.info( @@ -546,12 +545,11 @@ def train_dataloader(self): return DataLoader( dataset, batch_size=None, num_workers=self.hparams.num_workers ) - def val_dataloader(self): """ - validation samples have been selected, but to structure them and make them load - as a batch they are packed in a dataset class, which is then wrapped by a + validation samples have been selected, but to structure them and make them load + as a batch they are packed in a dataset class, which is then wrapped by a dataloading object. 
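# Illustrative sketch, not part of this patch: the dataloading pattern used by
# train_dataloader()/val_dataloader() above -- the Dataset already returns fully
# formed batches, so the DataLoader is created with batch_size=None to disable
# automatic batching. The toy dataset and its shapes are invented for the example.
import torch
from torch.utils.data import DataLoader, Dataset

class PrebatchedDataset(Dataset):
    def __init__(self, n_samples: int = 100, batch_size: int = 16):
        self.n_samples = n_samples
        self.batch_size = batch_size

    def __len__(self):
        return -(-self.n_samples // self.batch_size)  # number of batches (ceiling division)

    def __getitem__(self, index):
        start = index * self.batch_size
        size = min(self.batch_size, self.n_samples - start)
        return {"covariates": torch.randn(size, 3), "y": torch.randn(size)}

dl = DataLoader(PrebatchedDataset(), batch_size=None, num_workers=0)
first_batch = next(iter(dl))
print(first_batch["covariates"].shape)  # torch.Size([16, 3])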
""" logger.info( @@ -600,7 +598,7 @@ def run_bagging( :returns: Optional[float]: computes the lowest scores of all loss metrics and returns their average :rtype: Optional[float] """ - + # if hyperparameter optimization is performed (train(); hpopt_file != None) if trial is not None: if trial_id is not None: @@ -618,7 +616,7 @@ def run_bagging( # - 64 # kwargs: # step: 16 - # + # # this line should be translated into: # phi_layers = optuna.suggest_int(name="phi_hidden_dim", low=16, high=64, step=16) # and afterward replace the respective area in config to set the suggestion. @@ -631,7 +629,7 @@ def run_bagging( with open(config_out, "w") as f: yaml.dump(config, f) - # in practice we only train a single bag, as there are + # in practice we only train a single bag, as there are # theoretical reasons to omit bagging w.r.t. association testing n_bags = config["training"]["n_bags"] if not debug else 3 train_proportion = config["training"].get("train_proportion", None) @@ -669,7 +667,7 @@ def run_bagging( **dm_kwargs, **config["training"]["dataloader_config"], ) - + # setup the model architecture as specified in config model_class = getattr(deeprvat_models, config["model"]["type"]) model = model_class( @@ -688,7 +686,7 @@ def run_bagging( objective = "val_" + config["model"]["config"]["metrics"]["objective"] checkpoint_callback = ModelCheckpoint(monitor=objective) callbacks = [checkpoint_callback] - + # to prune underperforming trials we enable a pruning strategy that can be set in config if "early_stopping" in config: callbacks.append( @@ -706,7 +704,7 @@ def run_bagging( while True: try: - # actual training of the model + # actual training of the model trainer.fit(model, dm) except RuntimeError as e: # if batch_size is choosen to big, it will be reduced until it fits the GPU @@ -836,8 +834,7 @@ def train( :raises ValueError: If no phenotype option is specified. """ - - + if len(phenotype) == 0: raise ValueError("At least one --phenotype option must be specified")