diff --git a/deeprvat/deeprvat/models.py b/deeprvat/deeprvat/models.py index 633bd63c..7f1739bb 100644 --- a/deeprvat/deeprvat/models.py +++ b/deeprvat/deeprvat/models.py @@ -43,6 +43,11 @@ def get_hparam(module: pl.LightningModule, param: str, default: Any): class BaseModel(pl.LightningModule): + """ + Base class containing functions that will be called by PyTorch Lightning in the + background by default. + """ + def __init__( self, config: dict, @@ -53,6 +58,23 @@ def __init__( stage: str = "train", **kwargs, ): + """ + Initializes BaseModel. + + :param config: Represents the content of config.yaml. + :type config: dict + :param n_annotations: Contains the number of annotations used for each phenotype. + :type n_annotations: Dict[str, int] + :param n_covariates: Contains the number of covariates used for each phenotype. + :type n_covariates: Dict[str, int] + :param n_genes: Contains the number of genes used for each phenotype. + :type n_genes: Dict[str, int] + :param phenotypes: Contains the phenotypes used during training. + :type phenotypes: List[str] + :param stage: Contains a prefix indicating the dataset the model is operating on. Defaults to "train". (optional) + :type stage: str + :param kwargs: Additional keyword arguments. + """ super().__init__() self.save_hyperparameters(config) self.save_hyperparameters(kwargs) @@ -75,6 +97,10 @@ def __init__( raise ValueError("Unknown objective_mode configuration parameter") def configure_optimizers(self) -> torch.optim.Optimizer: + """ + Function used to setup an optimizer and scheduler by their + parameters which are specified in config + """ optimizer_config = self.hparams["optimizer"] optimizer_class = getattr(torch.optim, optimizer_config["type"]) optimizer = optimizer_class( @@ -100,9 +126,25 @@ def configure_optimizers(self) -> torch.optim.Optimizer: return optimizer def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: + """ + Function called by trainer during training and returns the loss used + to update weights and biases. + + :param batch: A dictionary containing the batch data. + :type batch: dict + :param batch_idx: The index of the current batch. + :type batch_idx: int + + :returns: torch.Tensor: The loss value computed to update weights and biases + based on the predictions. + :raises RuntimeError: If NaNs are found in the training loss. + """ + # calls DeepSet.forward() y_pred_by_pheno = self(batch) results = dict() + # for all metrics we want to evaluate (specified in config) for name, fn in self.metric_fns.items(): + # compute mean distance in between ground truth and predicted score. results[name] = torch.mean( torch.stack( [ @@ -112,22 +154,49 @@ def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: ) ) self.log(f"{self.hparams.stage}_{name}", results[name]) - + # set loss from which we compute backward passes loss = results[self.hparams.metrics["loss"]] if torch.any(torch.isnan(loss)): raise RuntimeError("NaNs found in training loss") return loss def validation_step(self, batch: dict, batch_idx: int): + """ + During validation, we do not compute backward passes, such that we can accumulate + phenotype predictions and evaluate them afterward as a whole. + + :param batch: A dictionary containing the validation batch data. + :type batch: dict + :param batch_idx: The index of the current validation batch. + :type batch_idx: int + + :returns: dict: A dictionary containing phenotype predictions ("y_pred_by_pheno") + and corresponding ground truth values ("y_by_pheno"). + """ y_by_pheno = {pheno: pheno_batch["y"] for pheno, pheno_batch in batch.items()} return {"y_pred_by_pheno": self(batch), "y_by_pheno": y_by_pheno} def validation_epoch_end( self, prediction_y: List[Dict[str, Dict[str, torch.Tensor]]] ): + """ + Evaluate accumulated phenotype predictions at the end of the validation epoch. + + This function takes a list of dictionaries containing accumulated phenotype predictions + and corresponding ground truth values obtained during the validation process. It computes + various metrics based on these predictions and logs the results. + + :param prediction_y: A list of dictionaries containing accumulated phenotype predictions + and corresponding ground truth values obtained during the validation process. + :type prediction_y: List[Dict[str, Dict[str, torch.Tensor]]] + + :return: None + :rtype: None + """ y_pred_by_pheno = dict() y_by_pheno = dict() for result in prediction_y: + # create a dict for each phenotype that includes all respective predictions pred = result["y_pred_by_pheno"] for pheno, ys in pred.items(): y_pred_by_pheno[pheno] = torch.cat( @@ -138,14 +207,16 @@ def validation_epoch_end( ys, ] ) - + # create a dict for each phenotype that includes the respective ground truth target = result["y_by_pheno"] for pheno, ys in target.items(): y_by_pheno[pheno] = torch.cat( [y_by_pheno.get(pheno, torch.tensor([], device=self.device)), ys] ) + # create a dict for each phenotype that stores the respective loss results = dict() + # for all metrics we want to evaluate (specified in config) for name, fn in self.metric_fns.items(): results[name] = torch.mean( torch.stack( @@ -156,15 +227,36 @@ def validation_epoch_end( ) ) self.log(f"val_{name}", results[name]) - + # consider all metrics only store the most min/max in self.best_objective + # to determine if progress was made in the last training epoch. self.best_objective = self.objective_operation( self.best_objective, results[self.hparams.metrics["objective"]].item() ) def test_step(self, batch: dict, batch_idx: int): + """ + During testing, we do not compute backward passes, such that we can accumulate + phenotype predictions and evaluate them afterward as a whole. + + :param batch: A dictionary containing the testing batch data. + :type batch: dict + :param batch_idx: The index of the current testing batch. + :type batch_idx: int + + :returns: dict: A dictionary containing phenotype predictions ("y_pred") + and corresponding ground truth values ("y"). + :rtype: dict + """ return {"y_pred": self(batch), "y": batch["y"]} def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]): + """ + Evaluate accumulated phenotype predictions at the end of the testing epoch. + + :param prediction_y: A list of dictionaries containing accumulated phenotype predictions + and corresponding ground truth values obtained during the testing process. + :type prediction_y: List[Dict[str, Dict[str, torch.Tensor]]] + """ y_pred = torch.cat([p["y_pred"] for p in prediction_y]) y = torch.cat([p["y"] for p in prediction_y]) @@ -182,6 +274,15 @@ def configure_callbacks(self): class DeepSetAgg(pl.LightningModule): + """ + class contains the gene impairment module used for burden computation. + + Variants are fed through an embedding network Phi to compute a variant embedding. + The variant embedding is processed by a permutation-invariant aggregation to yield a gene embedding. + Afterward, the second network Rho estimates the final gene impairment score. + All parameters of the gene impairment module are shared across genes and traits. + """ + def __init__( self, n_annotations: int, @@ -196,6 +297,32 @@ def __init__( use_sigmoid: bool = False, reverse: bool = False, ): + """ + Initializes the DeepSetAgg module. + + :param n_annotations: Number of annotations. + :type n_annotations: int + :param phi_layers: Number of layers in Phi. + :type phi_layers: int + :param phi_hidden_dim: Internal dimensionality of linear layers in Phi. + :type phi_hidden_dim: int + :param rho_layers: Number of layers in Rho. + :type rho_layers: int + :param rho_hidden_dim: Internal dimensionality of linear layers in Rho. + :type rho_hidden_dim: int + :param activation: Activation function used; should match its name in torch.nn. + :type activation: str + :param pool: Invariant aggregation function used to aggregate gene variants. Possible values: 'max', 'sum'. + :type pool: str + :param output_dim: Number of burden scores. Defaults to 1. (optional) + :type output_dim: int + :param dropout: Probability by which some parameters are set to 0. (optional) + :type dropout: Optional[float] + :param use_sigmoid: Whether to project burden scores to [0, 1]. Also used as a linear activation function during training. Defaults to False. (optional) + :type use_sigmoid: bool + :param reverse: Whether to reverse the burden score (used during association testing). Defaults to False. (optional) + :type reverse: bool + """ super().__init__() self.output_dim = output_dim @@ -205,6 +332,7 @@ def __init__( self.use_sigmoid = use_sigmoid self.reverse = reverse + # setup of Phi input_dim = n_annotations phi = [] for l in range(phi_layers): @@ -216,10 +344,12 @@ def __init__( input_dim = output_dim self.phi = nn.Sequential(OrderedDict(phi)) + # setup permutation-invariant aggregation function if pool not in ("sum", "max"): raise ValueError(f"Unknown pooling operation {pool}") self.pool = pool + # setup of Rho rho = [] for l in range(rho_layers - 1): output_dim = rho_hidden_dim @@ -231,12 +361,33 @@ def __init__( rho.append( (f"rho_linear_{rho_layers - 1}", nn.Linear(input_dim, self.output_dim)) ) + # No final non-linear activation function to keep the relationship between + # gene impairment scores and phenotypes linear self.rho = nn.Sequential(OrderedDict(rho)) def set_reverse(self, reverse: bool = True): + """ + Reverse burden score during association testing if the model predicts in negative space. + + :param reverse: Indicates whether the 'reverse' attribute should be set to True or False. + Defaults to True. + :type reverse: bool + + Note: + Compare associate.py, reverse_models() for further detail + """ self.reverse = reverse def forward(self, x): + """ + Perform a forward pass through the model. + + :param x: Batched input data + :type x: tensor + + :returns: Burden scores + :rtype: tensor + """ x = self.phi(x.permute((0, 1, 3, 2))) # x.shape = samples x genes x variants x phi_latent if self.pool == "sum": @@ -245,7 +396,7 @@ def forward(self, x): x = torch.max(x, dim=2).values # Now x.shape = samples x genes x phi_latent x = self.rho(x) - # x.shape = samples x genes x rho_latent + # x.shape = samples x genes x 1 if self.reverse: x = -x if self.use_sigmoid: @@ -254,6 +405,13 @@ def forward(self, x): class DeepSet(BaseModel): + """ + Wrapper class for burden computation, that also does phenotype prediction. + It inherits parameters from BaseModel, which is where Pytorch Lightning specific functions + like "training_step" or "validation_epoch_end" can be found. + Those functions are called in background by default. + """ + def __init__( self, config: dict, @@ -266,6 +424,28 @@ def __init__( reverse: bool = False, **kwargs, ): + """ + Initialize the DeepSet model. + + :param config: Containing the content of config.yaml. + :type config: dict + :param n_annotations: Contains the number of annotations used for each phenotype. + :type n_annotations: Dict[str, int] + :param n_covariates: Contains the number of covariates used for each phenotype. + :type n_covariates: Dict[str, int] + :param n_genes: Contains the number of genes used for each phenotype. + :type n_genes: Dict[str, int] + :param phenotypes: Contains the phenotypes used during training. + :type phenotypes: List[str] + :param agg_model: Model used for burden computation. If not provided, it will be initialized. (optional) + :type agg_model: Optional[pl.LightningModule / nn.Module] + :param use_sigmoid: Determines if burden scores should be projected to [0, 1]. Acts as a linear activation + function to mimic association testing during training. + :type use_sigmoid: bool + :param reverse: Determines if the burden score should be reversed (used during association testing). + :type reverse: bool + :param kwargs: Additional keyword arguments. + """ super().__init__( config, n_annotations, n_covariates, n_genes, phenotypes, **kwargs ) @@ -277,6 +457,9 @@ def __init__( pool = get_hparam(self, "pool", "sum") dropout = get_hparam(self, "dropout", None) + # self.agg_model compresses a batch + # from: samples x genes x annotations x variants + # to: samples x genes if agg_model is not None: self.agg_model = agg_model else: @@ -293,7 +476,11 @@ def __init__( reverse=reverse, ) self.agg_model.train(False if self.hparams.stage == "val" else True) + # afterwards genes are concatenated with covariates + # to: samples x (genes + covariates) + # dict of various linear layers used for phenotype prediction. + # Returns can be tested against ground truth data. self.gene_pheno = nn.ModuleDict( { pheno: nn.Linear( @@ -304,6 +491,19 @@ def __init__( ) def forward(self, batch): + """ + Forward pass through the model. + + :param batch: Dictionary of phenotypes, each containing the following keys: + - indices (tensor): Indices for the underlying dataframe. + - covariates (tensor): Covariates of samples, e.g., age. Content: samples x covariates. + - rare_variant_annotations (tensor): Annotated genomic variants. Content: samples x genes x annotations x variants. + - y (tensor): Actual phenotypes (ground truth data). + :type batch: dict + + :returns: Dictionary containing predicted phenotypes + :rtype: dict + """ result = dict() for pheno, this_batch in batch.items(): x = this_batch["rare_variant_annotations"] @@ -318,7 +518,23 @@ def forward(self, batch): class LinearAgg(pl.LightningModule): + """ + To capture only linear effect, this model can be used as it only uses a single + linear layer without a non-linear activation function. + It still contains the gene impairment module used for burden computation. + """ + def __init__(self, n_annotations: int, pool: str, output_dim: int = 1): + """ + Initialize the LinearAgg model. + + :param n_annotations: Number of annotations. + :type n_annotations: int + :param pool: Pooling method ("sum" or "max") to be used. + :type pool: str + :param output_dim: Dimensionality of the output. Defaults to 1. (optional) + :type output_dim: int + """ super().__init__() self.output_dim = output_dim @@ -328,6 +544,15 @@ def __init__(self, n_annotations: int, pool: str, output_dim: int = 1): self.linear = nn.Linear(n_annotations, self.output_dim) def forward(self, x): + """ + Perform a forward pass through the model. + + :param x: Batched input data + :type x: tensor + + :returns: Burden scores + :rtype: tensor + """ x = self.linear( x.permute((0, 1, 3, 2)) ) # x.shape = samples x genes x variants x output_dim @@ -340,6 +565,12 @@ def forward(self, x): class TwoLayer(BaseModel): + """ + Wrapper class to capture linear effects. Inherits parameters from BaseModel, + which is where Pytorch Lightning specific functions like "training_step" or + "validation_epoch_end" can be found. Those functions are called in background by default. + """ + def __init__( self, config: dict, @@ -349,6 +580,21 @@ def __init__( agg_model: Optional[nn.Module] = None, **kwargs, ): + """ + Initializes the TwoLayer model. + + :param config: Represents the content of config.yaml. + :type config: dict + :param n_annotations: Number of annotations. + :type n_annotations: int + :param n_covariates: Number of covariates. + :type n_covariates: int + :param n_genes: Number of genes. + :type n_genes: int + :param agg_model: Model used for burden computation. If not provided, it will be initialized. (optional) + :type agg_model: Optional[nn.Module] + :param kwargs: Additional keyword arguments. + """ super().__init__(config, n_annotations, n_covariates, n_genes, **kwargs) logger.info("Initializing TwoLayer model with parameters:") @@ -374,6 +620,19 @@ def __init__( self.gene_pheno = nn.Linear(self.hparams.n_covariates + self.hparams.n_genes, 1) def forward(self, batch): + """ + Forward pass through the model. + + :param batch: Dictionary of phenotypes, each containing the following keys: + - indices (tensor): Indices for the underlying dataframe. + - covariates (tensor): Covariates of samples, e.g., age. Content: samples x covariates. + - rare_variant_annotations (tensor): Annotated genomic variants. Content: samples x genes x annotations x variants. + - y (tensor): Actual phenotypes (ground truth data). + :type batch: dict + + :returns: Dictionary containing predicted phenotypes + :rtype: dict + """ # samples x genes x annotations x variants x = batch["rare_variant_annotations"] x = self.agg_model(x).squeeze(dim=2) # samples x genes diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index c2e2abd4..3b414b9b 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -82,6 +82,22 @@ def make_dataset_( training_dataset_file: str = None, pickle_only: bool = False, ): + """ + Subfunction of make_dataset() + Convert a dataset file to the sparse format used for training and testing associations + + :param config: Dictionary containing configuration parameters, build from YAML file + :type config: Dict + :param debug: Use a strongly reduced dataframe (optional) + :type debug: bool + :param training_dataset_file: Path to the file in which training data is stored. (optional) + :type training_dataset_file: str + :param pickle_only: If True, only store dataset as pickle file and return None. (optional) + :type pickle_only: bool + + :returns: Tuple containing input_tensor, covariates, and target values. + :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + """ n_phenotypes = config.get("n_phenotypes", None) if n_phenotypes is not None: if "seed_genes" in config: @@ -113,6 +129,7 @@ def make_dataset_( or training_dataset_file is None or not Path(training_dataset_file).is_file() ): + # load data into sparse data format ds = DenseGTDataset( gt_file=config["training_data"]["gt_file"], variant_file=config["training_data"]["variant_file"], @@ -190,6 +207,29 @@ def make_dataset( covariates_out_file: str, y_out_file: str, ): + """ + Uses function make_dataset_() to convert dataset to sparse format and stores the respective data + + :param debug: Use a strongly reduced dataframe + :type debug: bool + :param pickle_only: Flag to indicate whether only to save data using pickle + :type pickle_only: bool + :param compression_level: Level of compression in ZARR to be applied to training data. + :type compression_level: int + :param training_dataset_file: Path to the file in which training data is stored. (optional) + :type training_dataset_file: Optional[str] + :param config_file: Path to a YAML file, which serves for configuration. + :type config_file: str + :param input_tensor_out_file: Path to save the training data to. + :type input_tensor_out_file: str + :param covariates_out_file: Path to save the covariates to. + :type covariates_out_file: str + :param y_out_file: Path to save the ground truth data to. + :type y_out_file: str + + :returns: None + """ + with open(config_file) as f: config = yaml.safe_load(f) @@ -213,6 +253,11 @@ def make_dataset( class MultiphenoDataset(Dataset): + """ + class used to structure the data and present a __getitem__ function to + the dataloader, that will be used to load batches into the model + """ + def __init__( self, # input_tensor: zarr.core.Array, @@ -226,7 +271,21 @@ def __init__( # samples: Optional[Union[slice, np.ndarray]] = None, # genes: Optional[Union[slice, np.ndarray]] = None ): - "Initialization" + """ + Initialize the MultiphenoDataset. + + :param data: Underlying dataframe from which data is structured into batches. + :type data: Dict[str, Dict] + :param min_variant_count: Minimum number of variants available for each gene. + :type min_variant_count: int + :param batch_size: Number of samples/individuals available in one batch. + :type batch_size: int + :param split: Contains a prefix indicating the dataset the model operates on. Defaults to "train". (optional) + :type split: str + :param cache_tensors: Indicates if samples have been pre-loaded or need to be extracted from zarr. (optional) + :type cache_tensors: bool + """ + super().__init__() self.data = data @@ -258,6 +317,8 @@ def __init__( self.total_samples = sum([s.shape[0] for s in self.samples.values()]) self.batch_size = batch_size + # index all samples and categorize them by phenotype, such that we + # get a dataframe repreenting a chain of phenotypes self.sample_order = pd.DataFrame( { "phenotype": itertools.chain( @@ -269,6 +330,7 @@ def __init__( {"phenotype": pd.api.types.CategoricalDtype()} ) self.sample_order = self.sample_order.sample(n=self.total_samples) # shuffle + # phenotype specific index; e.g. 7. element total, 2. element for phenotype "Urate" self.sample_order["index"] = self.sample_order.groupby("phenotype").cumcount() def __len__(self): @@ -289,6 +351,7 @@ def __getitem__(self, index): result = dict() for pheno, df in samples_by_pheno: + # get phenotype specific sub-index idx = df["index"].to_numpy() annotations = ( @@ -307,6 +370,10 @@ def __getitem__(self, index): return result def subset_samples(self): + """ + Function used to sort out samples which contain real phenotypes with NaN values and + samples with fewer variants for each gene than the specified minimum. + """ for pheno, pheno_data in self.data.items(): # First sum over annotations (dim 2) for each variant in each gene. # Then get the number of non-zero values across all variants in all @@ -330,6 +397,10 @@ def subset_samples(self): class MultiphenoBaggingData(pl.LightningDataModule): + """ + Preprocess the underlying dataframe, to then load it into a dataset object + """ + def __init__( self, data: Dict[str, Dict], @@ -341,6 +412,26 @@ def __init__( num_workers: Optional[int] = 0, cache_tensors: bool = False, ): + """ + Initialize the MultiphenoBaggingData. + + :param data: Underlying dataframe from which data structured into batches. + :type data: Dict[str, Dict] + :param train_proportion: Percentage by which data is divided into training/validation split. + :type train_proportion: float + :param sample_with_replacement: If True, a sample can be selected multiple times in one epoch. Defaults to True. (optional) + :type sample_with_replacement: bool + :param min_variant_count: Minimum number of variants available for each gene. Defaults to 1. (optional) + :type min_variant_count: int + :param upsampling_factor: Percentual factor by which to upsample data; >= 1. Defaults to 1. (optional) + :type upsampling_factor: int + :param batch_size: Number of samples/individuals available in one batch. Defaults to None. (optional) + :type batch_size: Optional[int] + :param num_workers: Number of workers simultaneously putting data into RAM. Defaults to 0. (optional) + :type num_workers: Optional[int] + :param cache_tensors: Indicates if samples have been pre-loaded or need to be extracted from zarr. Defaults to False. (optional) + :type cache_tensors: bool + """ logger.info("Intializing datamodule") super().__init__() @@ -386,11 +477,14 @@ def __init__( else: n_train_samples = round(n_samples * train_proportion) rng = np.random.default_rng() + # select training samples from the underlying dataframe train_samples = np.sort( rng.choice( samples, size=n_train_samples, replace=sample_with_replacement ) ) + # samples which are not part of train_samples, but in samples + # are validation samples. pheno_data["samples"] = { "train": train_samples, "val": np.setdiff1d(samples, train_samples), @@ -405,6 +499,10 @@ def __init__( ) def upsample(self) -> np.ndarray: + """ + does not work at the moment for multi-phenotype training. Needs some minor changes + to make it work again + """ unique_values = self.y.unique() if unique_values.size() != torch.Size([2]): raise ValueError( @@ -428,6 +526,11 @@ def upsample(self) -> np.ndarray: self.samples = upsampled_indices def train_dataloader(self): + """ + trainning samples have been selected, but to structure them and make them load + as a batch they are packed in a dataset class, which is then wrapped by a + dataloading object. + """ logger.info( "Instantiating training dataloader " f"with batch size {self.hparams.batch_size}" @@ -444,6 +547,11 @@ def train_dataloader(self): ) def val_dataloader(self): + """ + validation samples have been selected, but to structure them and make them load + as a batch they are packed in a dataset class, which is then wrapped by a + dataloading object. + """ logger.info( "Instantiating validation dataloader " f"with batch size {self.hparams.batch_size}" @@ -469,10 +577,49 @@ def run_bagging( trial_id: Optional[int] = None, debug: bool = False, ) -> Optional[float]: + """ + Main function called during training. Also used for trial pruning and sampling new parameters in optuna. + + :param config: Dictionary containing configuration parameters, build from YAML file + :type config: Dict + :param data: Dict of phenotypes, each containing a dict storing the underlying data. + :type data: Dict[str, Dict] + :param log_dir: Path to where logs are written. + :type log_dir: str + :param checkpoint_file: Path to where the weights of the trained model should be saved. (optional) + :type checkpoint_file: Optional[str] + :param trial: Optuna object generated from the study. (optional) + :type trial: Optional[optuna.trial.Trial] + :param trial_id: Current trial in range n_trials. (optional) + :type trial_id: Optional[int] + :param debug: Use a strongly reduced dataframe + :type debug: bool + + :returns: Optional[float]: computes the lowest scores of all loss metrics and returns their average + :rtype: Optional[float] + """ + + # if hyperparameter optimization is performed (train(); hpopt_file != None) if trial is not None: if trial_id is not None: + # differentiate various repeats in their individual optimization trial.set_user_attr("user_id", trial_id) + # Parameters set in config can be used to indicate hyperparameter optimization. + # Such cases can be spotted by the following exemplary pattern: + # + # phi_hidden_dim: 20 + # hparam: + # type : int + # args: + # - 16 + # - 64 + # kwargs: + # step: 16 + # + # this line should be translated into: + # phi_layers = optuna.suggest_int(name="phi_hidden_dim", low=16, high=64, step=16) + # and afterward replace the respective area in config to set the suggestion. config["model"]["config"] = suggest_hparams(config["model"]["config"], trial) logger.info("Model hyperparameters this trial:") pprint(config["model"]["config"]) @@ -482,6 +629,8 @@ def run_bagging( with open(config_out, "w") as f: yaml.dump(config, f) + # in practice we only train a single bag, as there are + # theoretical reasons to omit bagging w.r.t. association testing n_bags = config["training"]["n_bags"] if not debug else 3 train_proportion = config["training"].get("train_proportion", None) logger.info(f"Training {n_bags} bagged models") @@ -511,6 +660,7 @@ def run_bagging( "cache_tensors", ) } + # load data into the required formate dm = MultiphenoBaggingData( this_data, train_proportion, @@ -518,6 +668,7 @@ def run_bagging( **config["training"]["dataloader_config"], ) + # setup the model architecture as specified in config model_class = getattr(deeprvat_models, config["model"]["type"]) model = model_class( config=config["model"]["config"], @@ -535,6 +686,8 @@ def run_bagging( objective = "val_" + config["model"]["config"]["metrics"]["objective"] checkpoint_callback = ModelCheckpoint(monitor=objective) callbacks = [checkpoint_callback] + + # to prune underperforming trials we enable a pruning strategy that can be set in config if "early_stopping" in config: callbacks.append( EarlyStopping(monitor=objective, **config["early_stopping"]) @@ -544,14 +697,17 @@ def run_bagging( config["pl_trainer"]["min_epochs"] = 10 config["pl_trainer"]["max_epochs"] = 20 + # initialize trainer, which will call background functionality trainer = pl.Trainer( logger=tb_logger, callbacks=callbacks, **config.get("pl_trainer", {}) ) while True: try: + # actual training of the model trainer.fit(model, dm) except RuntimeError as e: + # if batch_size is choosen to big, it will be reduced until it fits the GPU logging.error(f"Caught RuntimeError: {e}") if str(e).find("CUDA out of memory") != -1: if dm.hparams.batch_size > 4: @@ -650,6 +806,35 @@ def train( log_dir: str, hpopt_file: str, ): + """ + Main function called during training. Also used for trial pruning and sampling new parameters in Optuna. + + :param debug: Use a strongly reduced dataframe + :type debug: bool + :param training_gene_file: Path to a pickle file specifying on which genes training should be executed. (optional) + :type training_gene_file: Optional[str] + :param n_trials: Number of trials to be performed by the given setting. + :type n_trials: int + :param trial_id: Current trial in range n_trials. (optional) + :type trial_id: Optional[int] + :param sample_file: Path to a pickle file specifying which samples should be considered during training. (optional) + :type sample_file: Optional[str] + :param phenotype: Array of phenotypes, containing an array of paths where the underlying data is stored: + - str: Phenotype name + - str: Annotated gene variants as zarr file + - str: Covariates each sample as zarr file + - str: Ground truth phenotypes as zarr file + :type phenotype: Tuple[Tuple[str, str, str, str]] + :param config_file: Path to a YAML file, which serves for configuration. + :type config_file: str + :param log_dir: Path to where logs are stored. + :type log_dir: str + :param hpopt_file: Path to where a .db file should be created in which the results of hyperparameter optimization are stored. + :type hpopt_file: str + + :raises ValueError: If no phenotype option is specified. + """ + if len(phenotype) == 0: raise ValueError("At least one --phenotype option must be specified") @@ -677,6 +862,7 @@ def train( samples = slice(None) data = dict() + # pack underlying data into a single dict that can be passed to downstream functions for pheno, input_tensor_file, covariates_file, y_file in phenotype: data[pheno] = dict() data[pheno]["input_tensor_zarr"] = zarr.open(input_tensor_file, mode="r") @@ -769,6 +955,22 @@ def train( def best_training_run( debug: bool, log_dir: str, checkpoint_dir: str, hpopt_db: str, config_file_out: str ): + """ + Function to extract the best trial from an Optuna study and handle associated model checkpoints and configurations. + + :param debug: Use a strongly reduced dataframe + :type debug: bool + :param log_dir: Path to where logs are stored. + :type log_dir: str + :param checkpoint_dir: Directory where checkpoints have been stored. + :type checkpoint_dir: str + :param hpopt_db: Path to the database file containing the Optuna study results. + :type hpopt_db: str + :param config_file_out: Path to store a reduced configuration file. + :type config_file_out: str + + :returns: None + """ study = optuna.load_study( study_name=Path(hpopt_db).stem, storage=f"sqlite:///{hpopt_db}" ) @@ -792,6 +994,7 @@ def best_training_run( link_path.symlink_to(checkpoint.resolve(strict=True)) # Keep track of models marked to be dropped + # respective models are not used for downstream processing checkpoint_dropped = Path(str(checkpoint) + ".dropped") if checkpoint_dropped.is_file(): dropped_link_path = Path(checkpoint_dir) / f"bag_{k}.ckpt.dropped"