diff --git a/deeprvat/deeprvat/models.py b/deeprvat/deeprvat/models.py index 7f1739bb..6fc189d4 100644 --- a/deeprvat/deeprvat/models.py +++ b/deeprvat/deeprvat/models.py @@ -524,7 +524,9 @@ class LinearAgg(pl.LightningModule): It still contains the gene impairment module used for burden computation. """ - def __init__(self, n_annotations: int, pool: str, output_dim: int = 1): + def __init__( + self, n_annotations: int, pool: str, output_dim: int = 1, reverse: bool = False + ): """ Initialize the LinearAgg model. @@ -542,6 +544,20 @@ def __init__(self, n_annotations: int, pool: str, output_dim: int = 1): input_dim = n_annotations self.linear = nn.Linear(n_annotations, self.output_dim) + self.reverse = reverse + + def set_reverse(self, reverse: bool = True): + """ + Reverse burden score during association testing if the model predicts in negative space. + + :param reverse: Indicates whether the 'reverse' attribute should be set to True or False. + Defaults to True. + :type reverse: bool + + Note: + Compare associate.py, reverse_models() for further detail + """ + self.reverse = reverse def forward(self, x): """ @@ -561,6 +577,8 @@ def forward(self, x): else: x = torch.max(x, dim=2).values # Now x.shape = samples x genes x output_dim + if self.reverse: + x = -x return x @@ -617,7 +635,14 @@ def __init__( for param in self.agg_model.parameters(): param.requires_grad = True - self.gene_pheno = nn.Linear(self.hparams.n_covariates + self.hparams.n_genes, 1) + self.gene_pheno = nn.ModuleDict( + { + pheno: nn.Linear( + self.hparams.n_covariates + self.hparams.n_genes[pheno], 1 + ) + for pheno in self.hparams.phenotypes + } + ) def forward(self, batch): """ @@ -633,9 +658,14 @@ def forward(self, batch): :returns: Dictionary containing predicted phenotypes :rtype: dict """ - # samples x genes x annotations x variants - x = batch["rare_variant_annotations"] - x = self.agg_model(x).squeeze(dim=2) # samples x genes - x = torch.cat((batch["covariates"], x), dim=1) - x = self.gene_pheno(x).squeeze(dim=1) # samples - return x + result = dict() + for pheno, this_batch in batch.items(): + x = this_batch["rare_variant_annotations"] + # x.shape = samples x genes x annotations x variants + x = self.agg_model(x).squeeze(dim=2) + # x.shape = samples x genes + x = torch.cat((this_batch["covariates"], x), dim=1) + # x.shape = samples x (genes + covariates) + result[pheno] = self.gene_pheno[pheno](x).squeeze(dim=1) + # result[pheno].shape = samples + return result diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index ae9995a0..e538c0e1 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -1172,12 +1172,21 @@ def best_training_run( :returns: None """ + study = optuna.load_study( study_name=Path(hpopt_db).stem, storage=f"sqlite:///{hpopt_db}" ) trials = study.trials_dataframe().query('state == "COMPLETE"') - best_trial = trials.sort_values("value", ascending=False).iloc[0] + with open("config.yaml") as f: + config = yaml.safe_load(f) + ascending = ( + False + if config["hyperparameter_optimization"]["direction"] == "maximize" + else True + ) + f.close() + best_trial = trials.sort_values("value", ascending=ascending).iloc[0] best_trial_id = best_trial["user_attrs_user_id"] logger.info(f"Best trial:\n{best_trial}")