Bug fix lasso loss variable naming

meyerkm committed Dec 12, 2023
1 parent 8b2c9d8 commit 5813428
Showing 3 changed files with 27 additions and 10 deletions.
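In short: the Lasso loss helpers in deeprvat/metrics.py took a single L2 term, so the gamma_skip penalty multiplied the same l2_weights value as gamma, and the skip connection's own L2 term was never applied. This commit passes self.l2_regularization_skip() through from the model and renames the loss parameters so each coefficient multiplies the intended term. Written out in the diff's own names (a sketch; the regularizers are scalars computed by the model):

    loss_val = mse(preds, y)
               + lambda_    * l1_skip_weights   # L1 on the skip layer -> drives sparsity
               + gamma      * l2_weights        # L2 on the phi & rho MLPs
               + gamma_skip * l2_skip_weights   # L2 on the skip layer (was l2_weights again)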
12 changes: 8 additions & 4 deletions deeprvat/deeprvat/models.py
@@ -437,7 +437,8 @@ def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
                 torch.stack([ (fn(y_pred, batch[pheno]["y"],
                                   self.hparams['gamma'],
                                   self.hparams['gamma_skip'],
-                                  self.l2_regularization()))
+                                  self.l2_regularization(),
+                                  self.l2_regularization_skip()))
                  for pheno, y_pred in y_pred_by_pheno.items()]))
             else:
                 results[name] = torch.mean(
@@ -487,7 +488,8 @@ def validation_epoch_end(self, prediction_y: List[Dict[str, Dict[str, torch.Tens
                                   self.hparams['gamma'],
                                   self.hparams['gamma_skip'],
                                   self.l1_regularization_skip().item(),
-                                  self.l2_regularization()))
+                                  self.l2_regularization(),
+                                  self.l2_regularization_skip().item()))
                  for pheno, y_pred in y_pred_by_pheno.items()]))
             else:
                 results[name] = torch.mean(
@@ -515,7 +517,8 @@ def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]):
                                   self.hparams['gamma'],
                                   self.hparams['gamma_skip'],
                                   self.l1_regularization_skip().item(),
-                                  self.l2_regularization()))
+                                  self.l2_regularization(),
+                                  self.l2_regularization_skip().item()))
             else:
                 results[name] = (fn(y_pred, y))
             self.log(f"val_{name}", results[name])
@@ -682,6 +685,7 @@ def lambda_start(
         factor=2,
     ):
         """Estimate when the model will start to sparsify."""
+        # TODO: fix this for group vs regular Lasso
         def is_sparse(lambda_):
             with torch.no_grad():
                 beta = self.skip.weight.data
@@ -707,7 +711,7 @@ def is_sparse(lambda_):

     def l2_regularization(self):
         """
-        L2 regulatization of the MLPs in phi & rho without the first layer
+        L2 regularization of the MLPs in phi & rho without the first layer
         which is bounded by the skip connection
         """
         ans = 0
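For context, l2_regularization() (partially shown above) penalizes the phi & rho MLPs, while l2_regularization_skip() is not shown in this diff. In LassoNet-style models the skip connection is a linear layer (the is_sparse helper above reads self.skip.weight.data), so a plausible sketch of the two skip regularizers, offered as an assumption rather than the repository's actual code, is:

    import torch

    def l1_regularization_skip(self) -> torch.Tensor:
        # L1 norm of the linear skip connection; lambda_ * this is the sparsity penalty.
        return self.skip.weight.abs().sum()

    def l2_regularization_skip(self) -> torch.Tensor:
        # Squared L2 norm of the same weights; gamma_skip * this is the newly passed term.
        return (self.skip.weight ** 2).sum()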
15 changes: 14 additions & 1 deletion deeprvat/deeprvat/train.py
@@ -660,7 +660,7 @@ def run_lassotraining(k,current_lambda):
         n_covariates=dm.n_covariates,
         n_genes=dm.n_genes,
         gene_count=gene_count,
-        lambda_= current_lambda, #config['model']["config"]['lambda'],
+        lambda_=current_lambda, #config['model']["config"]['lambda'],
         gamma=0.0,
         gamma_skip=0.0,
         M=10.0,
@@ -677,6 +677,19 @@
         lambda_= current_lambda,
     )

+    #####
+    #mock up lambda_schedule prediction from model attribute
+    # print('STARTING LAMBDA INIT FINDER NOW!!!!!!!!')
+    # est_lambda = model.lambda_start(M=10.0)
+    # start_l = (est_lambda
+    #            / config["model"]["config"]["optimizer"]["config"]["lr"]
+    #            / 10 # divide by 10 for initial training
+    #            )
+    # print(f"estimated lambda: {est_lambda}")
+    # print(f"Estimated Start Lambda : {start_l}")
+    # import pdb; pdb.set_trace()
+    ########
+
     tb_log_dir = f"{log_dir}/bag_{k}/lambda_{current_lambda}"
     logger.info(f" Writing TensorBoard logs to {tb_log_dir}")
     tb_logger = TensorBoardLogger(log_dir, name=f"bag_{k}/lambda_{current_lambda}")
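The commented-out mock-up added above would seed the lambda schedule from the model itself; if enabled, it amounts to the following (a sketch assuming the config keys quoted in the comments):

    # Hypothetical, mirroring the commented block: estimate the lambda at which
    # the skip weights start to zero out, then rescale it for the first epoch.
    est_lambda = model.lambda_start(M=10.0)
    lr = config["model"]["config"]["optimizer"]["config"]["lr"]
    start_l = est_lambda / lr / 10  # divide by 10 for initial training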
10 changes: 5 additions & 5 deletions deeprvat/metrics.py
@@ -128,19 +128,19 @@ class LassoLossVal:
     def __init__(self):
         pass

-    def __call__(self, preds, y, lambda_, gamma, gamma_skip, l1_weights, l2_weights):
+    def __call__(self, preds, y, lambda_, gamma, gamma_skip, l1_skip_weights, l2_weights, l2_skip_weights):
         x = (F.mse_loss(preds, y)
-             + lambda_ * l1_weights
+             + lambda_ * l1_skip_weights
              + gamma * l2_weights
-             + gamma_skip * l2_weights)
+             + gamma_skip * l2_skip_weights)
         return x

 class LassoLossTrain:
     def __init__(self):
         pass

-    def __call__(self, preds, y, gamma, gamma_skip, l2_weights):
+    def __call__(self, preds, y, gamma, gamma_skip, l2_weights, l2_skip_weights):
         x = (F.mse_loss(preds, y)
              + gamma * l2_weights
-             + gamma_skip * l2_weights)
+             + gamma_skip * l2_skip_weights)
         return x
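A quick usage sketch of the corrected signatures (the scalar penalty values are made up, and the import path assumes the file above is importable as deeprvat.metrics). Before this commit, gamma_skip would have multiplied l2_weights a second time:

    import torch
    from deeprvat.metrics import LassoLossVal, LassoLossTrain

    preds = torch.tensor([0.5, 1.5])
    y = torch.tensor([1.0, 1.0])  # mse = 0.25

    val = LassoLossVal()(preds, y, lambda_=0.1, gamma=0.01, gamma_skip=0.01,
                         l1_skip_weights=2.0,   # l1_regularization_skip().item()
                         l2_weights=3.0,        # l2_regularization()
                         l2_skip_weights=4.0)   # l2_regularization_skip().item()
    # 0.25 + 0.1*2.0 + 0.01*3.0 + 0.01*4.0 = 0.52

    train = LassoLossTrain()(preds, y, gamma=0.01, gamma_skip=0.01,
                             l2_weights=3.0, l2_skip_weights=4.0)
    # 0.25 + 0.01*3.0 + 0.01*4.0 = 0.32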
