diff --git a/deeprvat/deeprvat/models.py b/deeprvat/deeprvat/models.py index 6f5811e4..0a660522 100644 --- a/deeprvat/deeprvat/models.py +++ b/deeprvat/deeprvat/models.py @@ -9,7 +9,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import ModelSummary -from deeprvat.utils import init_model +from deeprvat.utils import init_model, inplace_prox, prox from deeprvat.deeprvat.submodules import Pooling, Layers from deeprvat.metrics import ( PearsonCorr, @@ -19,6 +19,8 @@ QuantileLoss, KLDIVLoss, BCELoss, + LassoLossTrain, + LassoLossVal, ) logging.basicConfig( @@ -40,6 +42,8 @@ "QuantileLoss": QuantileLoss, "KLDiv": KLDIVLoss, "BCELoss": BCELoss, + "LassoLossTrain": LassoLossTrain, + "LassoLossVal": LassoLossVal, } NORMALIZATION = { @@ -78,10 +82,12 @@ def __init__(self, "gene_count", "max_n_variants", "phenotypes", "stage") - self.metric_fns = {name: METRICS[name]() - for name in self.hparams.metrics["all"]} + self.metric_fns_train = {name: METRICS[name]() + for name in self.hparams.metrics_train["all"]} + self.metric_fns_val = {name: METRICS[name]() + for name in self.hparams.metrics_val["all"]} - self.objective_mode = self.hparams.metrics.get("objective_mode", "min") + self.objective_mode = self.hparams.metrics_train.get("objective_mode", "min") if self.objective_mode == "max": self.best_objective = float("-inf") self.objective_operation = max @@ -114,13 +120,13 @@ def configure_optimizers(self) -> torch.optim.Optimizer: def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: y_pred_by_pheno = self(batch) results = dict() - for name, fn in self.metric_fns.items(): + for name, fn in self.metric_fns_train.items(): results[name] = torch.mean( torch.stack([fn(y_pred, batch[pheno]["y"]) for pheno, y_pred in y_pred_by_pheno.items()])) self.log(f"{self.hparams.stage}_{name}", results[name]) - loss = results[self.hparams.metrics["loss"]] + loss = results[self.hparams.metrics_train["loss"]] if torch.any(torch.isnan(loss)): raise RuntimeError("NaNs found in training loss") return loss @@ -146,7 +152,7 @@ def validation_epoch_end(self, prediction_y: List[Dict[str, Dict[str, torch.Tens y_by_pheno.get(pheno, torch.tensor([],device=self.device)), ys]) results = dict() - for name, fn in self.metric_fns.items(): + for name, fn in self.metric_fns_val.items(): results[name] = torch.mean( torch.stack([ fn(y_pred, y_by_pheno[pheno]) @@ -154,7 +160,7 @@ def validation_epoch_end(self, prediction_y: List[Dict[str, Dict[str, torch.Tens self.log(f"val_{name}", results[name]) self.best_objective = self.objective_operation( - self.best_objective, results[self.hparams.metrics["objective"]].item()) + self.best_objective, results[self.hparams.metrics_val["objective"]].item()) def test_step(self, batch: dict, batch_idx: int): return {"y_pred": self(batch), "y": batch["y"]} @@ -164,12 +170,12 @@ def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]): y = torch.cat([p["y"] for p in prediction_y]) results = {} - for name, fn in self.metric_fns.items(): + for name, fn in self.metric_fns_val.items(): results[name] = fn(y_pred, y) self.log(f"val_{name}", results[name]) self.best_objective = self.objective_operation(self.best_objective, - results[self.hparams.metrics["objective"]].item()) + results[self.hparams.metrics_val["objective"]].item()) def configure_callbacks(self): return [ModelSummary()] @@ -344,3 +350,395 @@ def forward(self, batch): pheno, this_batch["gene_id"]) return result + +class BaseLassoModel(pl.LightningModule): + def __init__(self, + config: dict, + 
n_annotations: Dict[str, int], + n_covariates: Dict[str, int], + n_genes: Dict[str, int], + gene_count: int, + max_n_variants: int, + lambda_: int, + gamma: float, + gamma_skip: float, + M:float, + phenotypes: List[str], + stage: str = "train", + **kwargs): + """ + Parameters + ---------- + gamma : float, default=0.0 + l2 penalization on the network + gamma_skip : float, default=0.0 + l2 penalization on the skip connection + M : float, default=10.0 + Hierarchy parameter + """ + super().__init__() + + self.automatic_optimization = False ############## + + self.save_hyperparameters(config) + self.save_hyperparameters(kwargs) + self.save_hyperparameters("n_annotations", "n_covariates", "n_genes", + "gene_count", "max_n_variants", + "lambda_", "gamma", "gamma_skip", + "M", "phenotypes", "stage") + + self.metric_fns_train = {name: METRICS[name]() + for name in self.hparams.metrics_train["all"]} + self.metric_fns_val = {name: METRICS[name]() + for name in self.hparams.metrics_val["all"]} + + self.objective_mode = self.hparams.metrics_train.get("objective_mode", "min") + if self.objective_mode == "max": + self.best_objective = float("-inf") + self.objective_operation = max + elif self.objective_mode == "min": + self.best_objective = float("inf") + self.objective_operation = min + else: + raise ValueError("Unknown objective_mode configuration parameter") + + def configure_optimizers(self) -> torch.optim.Optimizer: + optimizer_config = self.hparams["optimizer"] + optimizer_class = getattr(torch.optim, optimizer_config["type"]) + optimizer = optimizer_class(self.parameters(), + **optimizer_config.get("config", {})) + + lrsched_config = optimizer_config.get("lr_scheduler", None) + if lrsched_config is not None: + lr_scheduler_class = getattr(torch.optim.lr_scheduler, + lrsched_config["type"]) + lr_scheduler = lr_scheduler_class(optimizer, + **lrsched_config["config"]) + + if lrsched_config["type"] == "ReduceLROnPlateau": + return {"optimizer": optimizer, + "lr_scheduler": lr_scheduler, + "monitor": lrsched_config["monitor"]} + else: return [optimizer], [lr_scheduler] + else: return optimizer + + def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: + + self.log(f"Lambda = ", torch.FloatTensor([self.hparams['lambda_']])) + self.log(f"selected features = ", torch.FloatTensor([self.selected_count()])) + + opt = self.optimizers() + + y_pred_by_pheno = self(batch) + results = dict() + for name, fn in self.metric_fns_train.items(): + if name == "LassoLossTrain": + results[name] = torch.mean( + torch.stack([ (fn(y_pred, batch[pheno]["y"], + self.hparams['gamma'], + self.hparams['gamma_skip'], + self.l2_regularization())) + for pheno, y_pred in y_pred_by_pheno.items()])) + else: + results[name] = torch.mean( + torch.stack([ (fn(y_pred, batch[pheno]["y"])) + for pheno, y_pred in y_pred_by_pheno.items()])) + + self.log(f"{self.hparams.stage}_{name}", results[name]) + + loss = results[self.hparams.metrics_train["loss"]] + if torch.any(torch.isnan(loss)): + raise RuntimeError("NaNs found in training loss") + + opt.zero_grad() + self.manual_backward(loss) + opt.step() + self.prox(lambda_=self.hparams['lambda_'] * opt.param_groups[0]["lr"], M=self.hparams['M']) + + return loss + + def validation_step(self, batch: dict, batch_idx: int): + y_by_pheno = {pheno: pheno_batch["y"] + for pheno, pheno_batch in batch.items()} + return {"y_pred_by_pheno": self(batch), "y_by_pheno": y_by_pheno} + + def validation_epoch_end(self, prediction_y: List[Dict[str, Dict[str, torch.Tensor]]]): + y_pred_by_pheno = 
dict() + y_by_pheno = dict() + for result in prediction_y: + pred = result["y_pred_by_pheno"] + for pheno, ys in pred.items(): + y_pred_by_pheno[pheno] = torch.cat([ + y_pred_by_pheno.get(pheno, + torch.tensor([], device=self.device)), ys]) + + target = result["y_by_pheno"] + for pheno, ys in target.items(): + y_by_pheno[pheno] = torch.cat([ + y_by_pheno.get(pheno, torch.tensor([],device=self.device)), ys]) + + results = dict() + for name, fn in self.metric_fns_val.items(): + if name == "LassoLossVal": + results[name] = torch.mean( + torch.stack([ + (fn(y_pred, y_by_pheno[pheno], + self.hparams['lambda_'], + self.hparams['gamma'], + self.hparams['gamma_skip'], + self.l1_regularization_skip().item(), + self.l2_regularization())) + for pheno, y_pred in y_pred_by_pheno.items()])) + else: + results[name] = torch.mean( + torch.stack([ + (fn(y_pred, y_by_pheno[pheno])) + for pheno, y_pred in y_pred_by_pheno.items()])) + + self.log(f"val_{name}", results[name]) + + self.best_objective = self.objective_operation( + self.best_objective, results[self.hparams.metrics_val["objective"]].item()) + + def test_step(self, batch: dict, batch_idx: int): + return {"y_pred": self(batch), "y": batch["y"]} + + def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]): + y_pred = torch.cat([p["y_pred"] for p in prediction_y]) + y = torch.cat([p["y"] for p in prediction_y]) + + results = {} + for name, fn in self.metric_fns_val.items(): + if name == "LassoLossVal": + results[name] = (fn(y_pred, y, + self.hparams['lambda_'], + self.hparams['gamma'], + self.hparams['gamma_skip'], + self.l1_regularization_skip().item(), + self.l2_regularization())) + else: + results[name] = (fn(y_pred, y)) + self.log(f"val_{name}", results[name]) + + self.best_objective = self.objective_operation(self.best_objective, + results[self.hparams.metrics_val["objective"]].item()) + + def configure_callbacks(self): + return [ModelSummary()] + +class DeepSetLassoAgg(pl.LightningModule): + def __init__( + self, + deep_rvat: int, + pool_layer: str, + skip: int, + use_sigmoid: bool = False, + reverse: bool = False, + ): + super().__init__() + + self.deep_rvat = deep_rvat + self.pool_layer = pool_layer + self.skip = skip + self.use_sigmoid = use_sigmoid + self.reverse = reverse + + def set_reverse(self, reverse: bool = True): + self.reverse = reverse + + def forward(self, x): + x_base = x.permute((0,1,3,2)) + + x = x_base #x.permute((0, 1, 3, 2)) + # x.shape = samples x genes x variants x annotations + x = self.deep_rvat(x) + # x.shape = samples x genes + x = x + self.skip(torch.max(x_base, dim=2).values) + + if self.reverse: x = -x + if self.use_sigmoid: x = torch.sigmoid(x) + # burden_score + return x + +class DeepSetLasso(BaseLassoModel): + def __init__( + self, + config: dict, + n_annotations: Dict[str, int], + n_covariates: Dict[str, int], + n_genes: Dict[str, int], + gene_count: int, + max_n_variants: int, + lambda_: float, + gamma: float, + gamma_skip: float, + M: float, + phenotypes: List[str], + agg_model: Optional[nn.Module] = None, + **kwargs): + """ + Adapted functions from LassoNet: + Lemhadri, I., Ruan, F., Abraham, L., & Tibshirani, R. (2021). + Lassonet: A neural network with feature sparsity. + The Journal of Machine Learning Research, 22(1), 5633-5661. 
+ https://github.com/lasso-net/lassonet/tree/master + + """ + super().__init__( + config, + n_annotations, + n_covariates, + n_genes, + gene_count, + max_n_variants, + lambda_, + gamma, + gamma_skip, + M, + phenotypes, + **kwargs) + + logger.info("Initializing DeepSet model with parameters:") + pprint(self.hparams) + + self.normalization = getattr(self.hparams, "normalization", False) + self.activation = getattr(nn, getattr(self.hparams, "activation", "LeakyReLU"))() + self.use_sigmoid = getattr(self.hparams, "use_sigmoid", False) + self.reverse = getattr(self.hparams, "reverse", False) + self.pool_layer = getattr(self.hparams, "pool", "sum") + self.init_power_two = getattr(self.hparams, "first_layer_nearest_power_two", False) + self.steady_dim = getattr(self.hparams, "steady_dim", False) + + self.phi = self.get_model("phi", + n_annotations, + self.hparams.phi_hidden_dim, + self.hparams.phi_layers, + self.hparams.phi_res_layers) + + self.pool = Pooling(self.normalization, self.pool_layer, self.hparams.phi_hidden_dim, max_n_variants) + self.rho = self.get_model("rho", + self.hparams.phi_hidden_dim, + self.hparams.rho_hidden_dim, + self.hparams.rho_layers - 1, + self.hparams.rho_res_layers) + self.gene_pheno = Phenotype_classifier(self.hparams, phenotypes, n_genes, gene_count) + + self.deep_rvat = lambda x : self.rho(self.pool(self.phi(x))) + + self.skip = nn.Linear(n_annotations, 1, bias=False) + + if agg_model is not None: + self.agg_model = agg_model + else: + self.agg_model = DeepSetLassoAgg( + deep_rvat=self.deep_rvat, + pool_layer=self.pool_layer, + skip=self.skip, + use_sigmoid=self.use_sigmoid, + reverse=self.reverse + ) + self.agg_model.train(False if self.hparams.stage == "val" else True) + + self.train(False if self.hparams.stage == "val" else True) + + def get_model(self, prefix, input_dim, output_dim, n_layers, res_layers): + model = [] + Layers_obj = Layers(n_layers, res_layers, input_dim, output_dim, self.activation, self.normalization, self.init_power_two, self.steady_dim) + for l in range(n_layers): + model.append((f"{prefix}_layer_{l}", Layers_obj.get_layer(l))) + model.append((f"{prefix}_activation_{l}", self.activation)) + if prefix == "rho": model.append((f"{prefix}_linear_{n_layers}", nn.Linear(output_dim, 1))) + model = nn.Sequential(OrderedDict(model)) + model = init_params(self.hparams, model) + return model + + def forward(self, batch): + result = dict() + for pheno, this_batch in batch.items(): + x = this_batch["rare_variant_annotations"] + # x.shape = samples x genes x annotations x variants + burden_score = self.agg_model(x).squeeze(dim=2) + result[pheno] = self.gene_pheno.forward(burden_score, + this_batch["covariates"], + pheno, + this_batch["gene_id"]) + return result + + def prox(self, *, lambda_, lambda_bar=0, M=1): + #self.groups is None: + with torch.no_grad(): + inplace_prox( + beta=self.skip, + theta=self.phi.phi_layer_0.layer, + lambda_=lambda_, + lambda_bar=lambda_bar, + M=M, + ) + + def lambda_start( + self, + M=1, + lambda_bar=0, + factor=2, + ): + """Estimate when the model will start to sparsify.""" + def is_sparse(lambda_): + with torch.no_grad(): + beta = self.skip.weight.data + theta = self.phi.phi_layer_0.layer.weight.data + + for _ in range(10000): + new_beta, theta = prox( + beta, + theta, + lambda_=lambda_, + lambda_bar=lambda_bar, + M=M, + ) + if torch.abs(beta - new_beta).max() < 1e-5: + break + beta = new_beta + return (torch.norm(beta, p=2, dim=0) == 0).sum() + + start = 1e-6 + while not is_sparse(factor * start): + start *= factor + 
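+        # `start` is now within one `factor` of the smallest penalty at which the
+        # proximal operator zeroes out at least one skip-connection column, i.e.
+        # the point where input features begin to be dropped.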
return start + + def l2_regularization(self): + """ + L2 regulatization of the MLPs in phi & rho without the first layer + which is bounded by the skip connection + """ + ans = 0 + for count, (name, param) in enumerate(self.phi.named_parameters()): + if 'weight' in name : + if count != 0: + # print(name.rstrip('.layer.weight')) + layer_obj = getattr(self.phi, name.rstrip('.layer.weight')) + ans += (torch.norm(layer_obj.layer.weight.data, p=2)** 2) + + for count, (name, param) in enumerate(self.rho.named_parameters()): + if 'weight' in name : + if 'linear' in name : #for last linear layer of rho + # print(name.rstrip('.layer.weight')) + layer_obj = getattr(self.rho, name.rstrip('.weight')) + ans += (torch.norm(layer_obj.weight.data, p=2)** 2) + else: + # print(name.rstrip('.layer.weight')) + layer_obj = getattr(self.rho, name.rstrip('.layer.weight')) + ans += (torch.norm(layer_obj.layer.weight.data, p=2)** 2) + return ans + + def l1_regularization_skip(self): + return torch.norm(self.skip.weight.data, p=2, dim=0).sum() + + def l2_regularization_skip(self): + return torch.norm(self.skip.weight.data, p=2) + + def input_mask(self): + with torch.no_grad(): + return torch.norm(self.skip.weight.data, p=2, dim=0) != 0 + + def selected_count(self): + return self.input_mask().sum().item() \ No newline at end of file diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index f31b3292..de3e7cbc 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -38,6 +38,8 @@ QuantileLoss, KLDIVLoss, BCELoss, + LassoLossTrain, + LassoLossVal, ) from deeprvat.utils import suggest_hparams @@ -62,6 +64,8 @@ "QuantileLoss": QuantileLoss, "KLDiv": KLDIVLoss, "BCELoss": BCELoss, + "LassoLossTrain": LassoLossTrain, + "LassoLossVal": LassoLossVal, } OPTIMIZERS = { "sgd": optim.SGD, @@ -500,9 +504,9 @@ def run_bagging( set_random_seed(config['model']["seed"]) logger.info(f'Set seed from config file') - results = [] - checkpoint_paths = [] - for k in range(n_bags): + lasso = config["model"].get("run_lasso", False) + + def run_training(k): logger.info(f" Starting training for bag {k}") this_data = copy.deepcopy(data) @@ -549,7 +553,7 @@ def run_bagging( logger.info(f" Writing TensorBoard logs to {tb_log_dir}") tb_logger = TensorBoardLogger(log_dir, name=f"bag_{k}") - objective = "val_" + config["model"]["config"]["metrics"]["objective"] + objective = "val_" + config["model"]["config"]["metrics_val"]["objective"] checkpoint_callback = ModelCheckpoint(monitor=objective) callbacks = [checkpoint_callback] if "early_stopping" in config: @@ -613,10 +617,204 @@ def run_bagging( gc.collect() torch.cuda.empty_cache() + def run_lassotraining(k,current_lambda): + nonlocal counter + nonlocal is_dense + logger.info(f" Starting lasso training for lambda {current_lambda}") + logger.info(f" Starting lasso training for bag {k}") + + this_data = copy.deepcopy(data) + for _, pheno_data in this_data.items(): + if pheno_data["training_genes"] is not None: + pheno_data["genes"] = pheno_data["training_genes"][f"bag_{k}"] + logger.info( + f'Using {len(pheno_data["genes"])} training genes ' + f'(out of {pheno_data["input_tensor_zarr"].shape[1]} total) at indices:' + ) + print(" ".join(map(str, pheno_data["genes"]))) + + dm_kwargs = { + k: v + for k, v in config["training"].items() + if k + in ( + "min_variant_count", + "upsampling_factor", + "sample_with_replacement", + "cache_tensors", + ) + } + dm = MultiphenoBaggingData( + this_data, + train_proportion, + **dm_kwargs, + 
**config["training"]["dataloader_config"], + ) + + model_class = getattr(deeprvat_models, config["model"]["type"]) + + if counter == 0: + model = model_class( + config=config["model"]["config"], + n_annotations=dm.n_annotations, + n_covariates=dm.n_covariates, + n_genes=dm.n_genes, + gene_count=gene_count, + lambda_= current_lambda, #config['model']["config"]['lambda'], + gamma=0.0, + gamma_skip=0.0, + M=10.0, + max_n_variants=dm.max_n_variants, + phenotypes=list(data.keys()), + **config["model"].get("kwargs", {}), + ) + else: + logger.info(f"Using previous checkpoint {checkpoint_paths[counter-1]} state to resume training with new Lambda {current_lambda}") + checkpoint = checkpoint_paths[counter-1] + model = model_class.load_from_checkpoint( + checkpoint, + config=config["model"]["config"], + lambda_= current_lambda, + ) + + tb_log_dir = f"{log_dir}/bag_{k}/lambda_{current_lambda}" + logger.info(f" Writing TensorBoard logs to {tb_log_dir}") + tb_logger = TensorBoardLogger(log_dir, name=f"bag_{k}/lambda_{current_lambda}") + + objective = "val_" + config["model"]["config"]["metrics_val"]["objective"] + checkpoint_callback = ModelCheckpoint(monitor=objective, + save_on_train_epoch_end=True) + callbacks = [checkpoint_callback] + if "early_stopping" in config: + callbacks.append( + EarlyStopping(monitor=objective, **config["early_stopping"]) + ) + + if debug: + config["pl_trainer"]["min_epochs"] = 10 + config["pl_trainer"]["max_epochs"] = 20 + + + trainer = pl.Trainer( + logger=tb_logger, callbacks=callbacks, **config.get("pl_trainer", {}) + ) + + while True: + try: + trainer.fit(model, dm) + counter += 1 + except RuntimeError as e: + logging.error(f"Caught RuntimeError: {e}") + if str(e).find("CUDA out of memory") != -1: + if dm.hparams.batch_size > 4: + logging.error( + f"Retrying training with half the original batch size" + ) + gc.collect() + torch.cuda.empty_cache() + dm.hparams.batch_size = dm.hparams.batch_size // 2 + else: + logging.error("Batch size is already <= 4, giving up") + raise RuntimeError("Could not find small enough batch size") + else: + logging.error(f"Caught unknown error: {e}") + raise e + else: + break + + logger.info( + "Training finished, max memory used: " + f"{torch.cuda.max_memory_allocated(0)}" + ) + + trial.set_user_attr( + f"bag_{k}_checkpoint_path", checkpoint_callback.best_model_path + ) + checkpoint_paths.append(checkpoint_callback.best_model_path) + + if checkpoint_file is not None: + logger.info( + f"Symlinking {checkpoint_callback.best_model_path}" + f" to {checkpoint_file}" + ) + Path(checkpoint_file).symlink_to( + Path(checkpoint_callback.best_model_path).resolve() + ) + + results.append(model.best_objective) + logger.info(f" Result this bag: {model.best_objective}") + + current_features = model.selected_count() + + if is_dense and current_features < dm.n_annotations: + is_dense = False + if current_lambda / lambda_start < 2: + assert ( + f"lambda_start={lambda_start:.3f} " + "selected lambda might be too large.\n" + f"Features start to disappear at {current_lambda=:.3f}." 
+ ) + print(f"Lambda = {current_lambda:.2e}, " + f"selected {current_features} features ") + + del dm + gc.collect() + torch.cuda.empty_cache() + + return current_features + + def build_lambda_scheduler(lambda_init, lambda_seq=None, path_multiplier=1.2, lambda_max=float("inf")): + #Build Lambda Sequence + if lambda_seq is not None: + lambda_seq = lambda_seq + else: + def _lambda_seq(start): + while start <= lambda_max: + yield start + start *= path_multiplier + + if lambda_init == "auto": + raise NotImplementedError("auto lambda initialization not yet implemented") + logger.info(" Building an auto-generated Lambda sequence.") + #TODO add M parameter to config and specify in DeepSetLasso init + lambda_start_ = ( + deeprvat_models.DeepSetLasso.lambda_start(M=10) #self.hparams['M']) + / config["model"]["config"]["optimizer"]["config"]["lr"] + / 10 # divide by 10 for initial training + ) + lambda_seq = _lambda_seq(lambda_start_) + else: + lambda_seq = _lambda_seq(lambda_init) + + # extract first value of lambda_seq + lambda_seq = iter(lambda_seq) + lambda_start = next(lambda_seq) + return lambda_start, lambda_seq + + results = [] + checkpoint_paths = [] + + if lasso: + lambda_init = config["model"].get("lambda_init", 1.64) + lambda_start, lambda_seq = build_lambda_scheduler(lambda_init) + + is_dense = True + counter = 0 + for current_lambda in itertools.chain([lambda_start], lambda_seq): + for k in range(n_bags): + current_features = run_lassotraining(k,current_lambda) + + if current_features == 0: + print('LASSO completed. Reached Sparsity; 0 selected features.') + break + else: + for k in range(n_bags): + run_training(k) + # Mark checkpoints with worst results to be dropped drop_n_bags = config["training"].get("drop_n_bags", None) if not debug else 1 if drop_n_bags is not None: - if config["model"]["config"]["metrics"].get("objective_mode", "max") == "max": + if config["model"]["config"]["metrics_train"].get("objective_mode", "max") == "max": min_result = sorted(results)[drop_n_bags] drop_bags = [(r < min_result) for r in results] else: @@ -636,7 +834,6 @@ def run_bagging( ) return final_result - @cli.command() @click.option("--debug", is_flag=True) @click.option("--training-gene-file", type=click.Path(exists=True)) @@ -790,7 +987,7 @@ def train( trial = study.best_trial logger.info(f'Best trial: {trial.user_attrs["user_id"]}') logger.info( - f' Mean {config["model"]["config"]["metrics"]["objective"]}: ' + f' Mean {config["model"]["config"]["metrics_val"]["objective"]}: ' f"{trial.value}" ) logger.info(f" Params:\n{pformat(trial.params)}") diff --git a/deeprvat/metrics.py b/deeprvat/metrics.py index 89fb9676..9d02dc1a 100644 --- a/deeprvat/metrics.py +++ b/deeprvat/metrics.py @@ -123,3 +123,24 @@ def __call__(self, preds, targets): bceloss = nn.BCEWithLogitsLoss() loss = bceloss(preds,targets) return loss + +class LassoLossVal: + def __init__(self): + pass + + def __call__(self, preds, y, lambda_, gamma, gamma_skip, l1_weights, l2_weights): + x = (F.mse_loss(preds, y) + + lambda_ * l1_weights + + gamma * l2_weights + + gamma_skip * l2_weights) + return x + +class LassoLossTrain: + def __init__(self): + pass + + def __call__(self, preds, y, gamma, gamma_skip, l2_weights): + x = (F.mse_loss(preds, y) + + gamma * l2_weights + + gamma_skip * l2_weights) + return x \ No newline at end of file diff --git a/deeprvat/utils.py b/deeprvat/utils.py index 2186786d..ab334b44 100644 --- a/deeprvat/utils.py +++ b/deeprvat/utils.py @@ -338,10 +338,8 @@ def sign_binary(x): def prox(v, u, *, lambda_, 
lambda_bar, M): """ - v has shape (m,) or (m, batches) - u has shape (k,) or (k, batches) - - supports GPU tensors + v of shape (m,) or (m, batches) + u of shape (k,) or (k, batches) """ onedim = len(v.shape) == 1 if onedim: diff --git a/example/config.yaml b/example/config.yaml index cd8119c0..4bc427cb 100644 --- a/example/config.yaml +++ b/example/config.yaml @@ -74,7 +74,16 @@ model: activation: LeakyReLU pool: max use_sigmoid: True - metrics: + metrics_train: + objective: MSE + objective_mode: min + loss: MSE + all: + MSE: {} + PearsonCorrTorch: {} + MAE: {} + RSquared: {} + metrics_val: objective: MSE objective_mode: min loss: MSE diff --git a/example/config_lasso.yaml b/example/config_lasso.yaml new file mode 100644 index 00000000..2f3ceee5 --- /dev/null +++ b/example/config_lasso.yaml @@ -0,0 +1,335 @@ +phenotypes: + Apolipoprotein_A: + correction_method: FDR + n_training_genes: 100 + baseline_phenotype: Apolipoprotein_A + Calcium: + correction_method: FDR + n_training_genes: 100 + baseline_phenotype: Calcium + +baseline_results: + - + base: baseline_results + type: plof/burden + - + base: baseline_results + type: missense/burden + - + base: baseline_results + type: plof/skat + - + base: baseline_results + type: missense/skat + +alpha: 0.05 + +n_burden_chunks: 2 +n_regression_chunks: 2 + +n_repeats: 2 + +do_scoretest: True + +training: + min_variant_count: 1 + n_bags: 1 + drop_n_bags: 0 + train_proportion: 0.8 + sample_with_replacement: False + dataloader_config: + batch_size: 1024 + num_workers: 32 + +pl_trainer: + gpus: 1 + precision: 16 + min_epochs: 3 + max_epochs: 5 + log_every_n_steps: 1 + check_val_every_n_epoch: 1 + +early_stopping: + mode: min + patience: 3 + min_delta: 0.00001 + verbose: True + +hyperparameter_optimization: + direction: maximize + n_trials: 1 + sampler: + type: TPESampler + config: {} + +model: + type: DeepSetLasso + model_collection: agg_models + checkpoint: combined_agg.pt + run_lasso: True + lambda_init: 1.64 + config: + phi_layers: 2 + phi_hidden_dim: 20 + rho_layers: 3 + rho_hidden_dim: 10 + activation: LeakyReLU + pool: max + use_sigmoid: True + metrics_train: + objective: MSE + objective_mode: min + loss: MSE + all: + MSE: {} + PearsonCorrTorch: {} + MAE: {} + RSquared: {} + metrics_val: + objective: MSE + objective_mode: min + loss: MSE + all: + MSE: {} + PearsonCorrTorch: {} + MAE: {} + RSquared: {} + optimizer: + type: AdamW + config: {} + +training_data: + gt_file: genotypes.h5 + variant_file: variants.parquet + dataset_config: + min_common_af: + MAF: 0.01 + phenotype_file: phenotypes.parquet + y_transformation: quantile_transform + x_phenotypes: + - age + - genetic_sex + - genetic_PC_1 + - genetic_PC_2 + - genetic_PC_3 + - genetic_PC_4 + - genetic_PC_5 + - genetic_PC_6 + - genetic_PC_7 + - genetic_PC_8 + - genetic_PC_9 + - genetic_PC_10 + - genetic_PC_11 + - genetic_PC_12 + - genetic_PC_13 + - genetic_PC_14 + - genetic_PC_15 + - genetic_PC_16 + - genetic_PC_17 + - genetic_PC_18 + - genetic_PC_19 + - genetic_PC_20 + annotation_file: annotations.parquet + annotations: + - MAF + - MAF_MB + - CADD_PHRED + - CADD_raw + - sift_score + - polyphen_score + - Consequence_splice_acceptor_variant + - Consequence_splice_donor_variant + - Consequence_stop_gained + - Consequence_frameshift_variant + - Consequence_stop_lost + - Consequence_start_lost + - Consequence_inframe_insertion + - Consequence_inframe_deletion + - Consequence_missense_variant + - Consequence_protein_altering_variant + - Consequence_splice_region_variant + - condel_score + - DeepSEA_PC_1 
+ - DeepSEA_PC_2 + - DeepSEA_PC_3 + - DeepSEA_PC_4 + - DeepSEA_PC_5 + - DeepSEA_PC_6 + - PrimateAI_score + - AbSplice_DNA + - DeepRipe_plus_QKI_lip_hg2 + - DeepRipe_plus_QKI_clip_k5 + - DeepRipe_plus_KHDRBS1_clip_k5 + - DeepRipe_plus_ELAVL1_parclip + - DeepRipe_plus_TARDBP_parclip + - DeepRipe_plus_HNRNPD_parclip + - DeepRipe_plus_MBNL1_parclip + - DeepRipe_plus_QKI_parclip + - SpliceAI_delta_score + use_common_variants: False + use_rare_variants: True + rare_embedding: + type: PaddedAnnotations + config: + annotations: + - MAF_MB + - CADD_raw + - sift_score + - polyphen_score + - Consequence_splice_acceptor_variant + - Consequence_splice_donor_variant + - Consequence_stop_gained + - Consequence_frameshift_variant + - Consequence_stop_lost + - Consequence_start_lost + - Consequence_inframe_insertion + - Consequence_inframe_deletion + - Consequence_missense_variant + - Consequence_protein_altering_variant + - Consequence_splice_region_variant + - condel_score + - DeepSEA_PC_1 + - DeepSEA_PC_2 + - DeepSEA_PC_3 + - DeepSEA_PC_4 + - DeepSEA_PC_5 + - DeepSEA_PC_6 + - PrimateAI_score + - AbSplice_DNA + - DeepRipe_plus_QKI_lip_hg2 + - DeepRipe_plus_QKI_clip_k5 + - DeepRipe_plus_KHDRBS1_clip_k5 + - DeepRipe_plus_ELAVL1_parclip + - DeepRipe_plus_TARDBP_parclip + - DeepRipe_plus_HNRNPD_parclip + - DeepRipe_plus_MBNL1_parclip + - DeepRipe_plus_QKI_parclip + - SpliceAI_delta_score + thresholds: + MAF: "MAF < 1e-2" + CADD_PHRED: "CADD_PHRED > 5" + verbose: True + low_memory: True + verbose: True + dataloader_config: + batch_size: 64 + num_workers: 8 + +data: + gt_file: genotypes.h5 + variant_file: variants.parquet + dataset_config: + min_common_af: + MAF: 0.01 + phenotype_file: phenotypes.parquet + y_transformation: quantile_transform + x_phenotypes: + - age + - genetic_sex + - genetic_PC_1 + - genetic_PC_2 + - genetic_PC_3 + - genetic_PC_4 + - genetic_PC_5 + - genetic_PC_6 + - genetic_PC_7 + - genetic_PC_8 + - genetic_PC_9 + - genetic_PC_10 + - genetic_PC_11 + - genetic_PC_12 + - genetic_PC_13 + - genetic_PC_14 + - genetic_PC_15 + - genetic_PC_16 + - genetic_PC_17 + - genetic_PC_18 + - genetic_PC_19 + - genetic_PC_20 + annotation_file: annotations.parquet + annotations: + - MAF + - MAF_MB + - CADD_PHRED + - CADD_raw + - sift_score + - polyphen_score + - Consequence_splice_acceptor_variant + - Consequence_splice_donor_variant + - Consequence_stop_gained + - Consequence_frameshift_variant + - Consequence_stop_lost + - Consequence_start_lost + - Consequence_inframe_insertion + - Consequence_inframe_deletion + - Consequence_missense_variant + - Consequence_protein_altering_variant + - Consequence_splice_region_variant + - condel_score + - DeepSEA_PC_1 + - DeepSEA_PC_2 + - DeepSEA_PC_3 + - DeepSEA_PC_4 + - DeepSEA_PC_5 + - DeepSEA_PC_6 + - PrimateAI_score + - AbSplice_DNA + - DeepRipe_plus_QKI_lip_hg2 + - DeepRipe_plus_QKI_clip_k5 + - DeepRipe_plus_KHDRBS1_clip_k5 + - DeepRipe_plus_ELAVL1_parclip + - DeepRipe_plus_TARDBP_parclip + - DeepRipe_plus_HNRNPD_parclip + - DeepRipe_plus_MBNL1_parclip + - DeepRipe_plus_QKI_parclip + - SpliceAI_delta_score + gene_file: protein_coding_genes.parquet + use_common_variants: False + use_rare_variants: True + rare_embedding: + type: PaddedAnnotations + config: + annotations: + - MAF_MB + - CADD_raw + - sift_score + - polyphen_score + - Consequence_splice_acceptor_variant + - Consequence_splice_donor_variant + - Consequence_stop_gained + - Consequence_frameshift_variant + - Consequence_stop_lost + - Consequence_start_lost + - Consequence_inframe_insertion + - 
Consequence_inframe_deletion
+          - Consequence_missense_variant
+          - Consequence_protein_altering_variant
+          - Consequence_splice_region_variant
+          - condel_score
+          - DeepSEA_PC_1
+          - DeepSEA_PC_2
+          - DeepSEA_PC_3
+          - DeepSEA_PC_4
+          - DeepSEA_PC_5
+          - DeepSEA_PC_6
+          - PrimateAI_score
+          - AbSplice_DNA
+          - DeepRipe_plus_QKI_lip_hg2
+          - DeepRipe_plus_QKI_clip_k5
+          - DeepRipe_plus_KHDRBS1_clip_k5
+          - DeepRipe_plus_ELAVL1_parclip
+          - DeepRipe_plus_TARDBP_parclip
+          - DeepRipe_plus_HNRNPD_parclip
+          - DeepRipe_plus_MBNL1_parclip
+          - DeepRipe_plus_QKI_parclip
+          - SpliceAI_delta_score
+        thresholds:
+          MAF: "MAF < 1e-3"
+          CADD_PHRED: "CADD_PHRED > 5"
+        gene_file: protein_coding_genes.parquet
+        verbose: True
+        low_memory: True
+    verbose: True
+  dataloader_config:
+    batch_size: 16
+    num_workers: 10
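
Illustrative sketch (not part of the patch): the loop below mirrors, on a toy model, the dense-to-sparse lambda path implemented by run_lassotraining/build_lambda_scheduler and the per-step proximal update in BaseLassoModel.training_step. TinyLassoNet, lambda_path and prox_ are hypothetical stand-ins; the prox here is plain group soft-thresholding of the skip columns and ignores the hierarchy parameter M and the first-layer coupling that deeprvat.utils.inplace_prox/prox enforce, and the gamma/gamma_skip L2 terms from LassoLossTrain are omitted.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLassoNet(nn.Module):
    def __init__(self, n_features: int, hidden: int = 16):
        super().__init__()
        self.hidden = nn.Sequential(
            nn.Linear(n_features, hidden), nn.ReLU(), nn.Linear(hidden, 1)
        )
        self.skip = nn.Linear(n_features, 1, bias=False)  # feature sparsity lives here

    def forward(self, x):
        return (self.hidden(x) + self.skip(x)).squeeze(-1)

    def prox_(self, threshold: float):
        # Group soft-threshold each skip column: shrink its L2 norm by `threshold`
        # and zero the column once its norm falls below the threshold.
        with torch.no_grad():
            w = self.skip.weight                       # shape (1, n_features)
            norms = w.norm(p=2, dim=0)
            scale = torch.clamp(1.0 - threshold / norms.clamp_min(1e-12), min=0.0)
            w.mul_(scale)

    def selected_count(self) -> int:
        return int((self.skip.weight.norm(p=2, dim=0) != 0).sum())

def lambda_path(model, x, y, lambda_start=1e-2, path_multiplier=1.2, lr=1e-2, epochs=50):
    lam = lambda_start
    while model.selected_count() > 0:                  # stop once the skip layer is fully sparse
        opt = torch.optim.AdamW(model.parameters(), lr=lr)
        for _ in range(epochs):
            opt.zero_grad()
            F.mse_loss(model(x), y).backward()         # training loss; L2 (gamma) terms omitted
            opt.step()
            model.prox_(threshold=lam * lr)            # proximal step right after the optimizer step
        print(f"lambda={lam:.3e}  selected features={model.selected_count()}")
        lam *= path_multiplier                         # next, more heavily penalized fit

x = torch.randn(256, 8)
y = x[:, 0] - 2 * x[:, 1] + 0.1 * torch.randn(256)     # only two informative features
lambda_path(TinyLassoNet(n_features=8), x, y)

Because the model is not reinitialized between penalties, each lambda is fit warm-started from the previous one, which is the same reason run_lassotraining reloads the previous bag checkpoint before training with the next lambda.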