Additional functionalities:
- added KLDiv, BCE, Quantile Loss
- fixed upsampling in dataloader
- code cleanup
meyerkm committed Oct 10, 2023
1 parent 4ae7d96 commit 301b448
Showing 3 changed files with 61 additions and 26 deletions.
23 changes: 10 additions & 13 deletions deeprvat/deeprvat/models.py
@@ -16,6 +16,9 @@
    PearsonCorrTorch,
    RSquared,
    AveragePrecisionWithLogits,
    QuantileLoss,
    KLDIVLoss,
    BCELoss,
)

logging.basicConfig(
@@ -34,6 +37,9 @@
    "PearsonCorrTorch": PearsonCorrTorch,
    "BCEWithLogits": nn.BCEWithLogitsLoss,
    "AveragePrecisionWithLogits": AveragePrecisionWithLogits,
    "QuantileLoss": QuantileLoss,
    "KLDiv": KLDIVLoss,
    "BCELoss": BCELoss,
}

NORMALIZATION = {
@@ -169,7 +175,7 @@ def configure_callbacks(self):
        return [ModelSummary()]

class Phenotype_classifier(pl.LightningModule):
    def __init__(self, hparams, phenotypes, n_genes, gene_count):
    def __init__(self, hparams, phenotypes, n_genes, gene_count, outdim=1):
        super().__init__()
        # pl.LightningModule already has attribute self.hparams,
        # which is inherited from its parent class
@@ -189,11 +195,11 @@ def __init__(self, hparams, phenotypes, n_genes, gene_count):
            self.pheno2id = dict(zip(phenotypes, range(len(phenotypes))))
            dim = self.hparams_.n_covariates + self.gene_count
            self.burden_pheno_embedding = self.get_embedding(len(phenotypes), dim)
            self.geno_pheno = self.get_model("Classification", dim, 1,
            self.geno_pheno = self.get_model("Classification", dim, outdim,
                getattr(self.hparams_, "classification_layers", 1), 0)
        else:
            self.geno_pheno = nn.ModuleDict({
                pheno: self.get_model("Classification", self.hparams_.n_covariates + self.hparams_.n_genes[pheno], 1,
                pheno: self.get_model("Classification", self.hparams_.n_covariates + self.hparams_.n_genes[pheno], outdim,
                    getattr(self.hparams_, "classification_layers", 1), 0)
                for pheno in self.hparams_.phenotypes
            })
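For orientation, here is a generic sketch of a classification head with a configurable output dimension, which is the behaviour the new outdim argument threads through to get_model; the helper name and layer sizes below are assumptions for illustration, not code from this repository.

import torch.nn as nn

def make_classification_head(in_dim: int, out_dim: int, n_layers: int) -> nn.Sequential:
    # Hypothetical stand-in for get_model("Classification", ...): only the idea of a
    # configurable output dimension (out_dim > 1, e.g. one output per quantile) is shown.
    layers = []
    for _ in range(n_layers):
        layers += [nn.Linear(in_dim, in_dim), nn.LeakyReLU()]
    layers.append(nn.Linear(in_dim, out_dim))
    return nn.Sequential(*layers)

head = make_classification_head(in_dim=34, out_dim=1, n_layers=1)  # out_dim=1 reproduces the previous hard-coded behaviour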
@@ -233,17 +239,13 @@ class DeepSetAgg(pl.LightningModule):
    def __init__(
        self,
        deep_rvat: int,
        pool_layer: str,
        use_sigmoid: bool = False,
        use_tanh: bool = False,
        reverse: bool = False,
    ):
        super().__init__()

        self.deep_rvat = deep_rvat
        self.pool_layer = pool_layer
        self.use_sigmoid = use_sigmoid
        self.use_tanh = use_tanh
        self.reverse = reverse

    def set_reverse(self, reverse: bool = True):
@@ -255,9 +257,7 @@ def forward(self, x):
        x = self.deep_rvat(x)
        # x.shape = samples x genes x latent
        if self.reverse: x = -x

        if self.use_sigmoid: x = torch.sigmoid(x)
        if self.use_tanh: x = torch.tanh(x)
        # burden_score
        return x
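As a toy illustration (assumed shapes and a stand-in network, not code from this repository), the slimmed-down forward pass amounts to: apply the wrapped deep_rvat network to obtain per-gene scores, then optionally flip the sign and squash with a sigmoid.

import torch
import torch.nn as nn

deep_rvat = nn.Sequential(nn.Linear(8, 16), nn.LeakyReLU(), nn.Linear(16, 1))  # stand-in for the real network
x = torch.randn(4, 10, 8)   # samples x genes x features (illustrative shapes)
x = deep_rvat(x)            # -> samples x genes x 1, the latent burden score
x = -x                      # what reverse=True does: flip the direction of effect
x = torch.sigmoid(x)        # what use_sigmoid=True does: map scores into (0, 1)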

@@ -289,7 +289,6 @@ def __init__(
        self.normalization = getattr(self.hparams, "normalization", False)
        self.activation = getattr(nn, getattr(self.hparams, "activation", "LeakyReLU"))()
        self.use_sigmoid = getattr(self.hparams, "use_sigmoid", False)
        self.use_tanh = getattr(self.hparams, "use_tanh", False)
        self.reverse = getattr(self.hparams, "reverse", False)
        self.pool_layer = getattr(self.hparams, "pool", "sum")
        self.init_power_two = getattr(self.hparams, "first_layer_nearest_power_two", False)
@@ -316,10 +315,8 @@ def __init__(
        else:
            self.agg_model = DeepSetAgg(
                deep_rvat=self.deep_rvat,
                pool_layer=self.pool_layer,
                use_sigmoid=self.use_sigmoid,
                use_tanh=self.use_tanh,
                reverse=self.reverse
                reverse=self.reverse,
            )
        self.agg_model.train(False if self.hparams.stage == "val" else True)

31 changes: 19 additions & 12 deletions deeprvat/deeprvat/train.py
@@ -35,6 +35,9 @@
    PearsonCorrTorch,
    RSquared,
    AveragePrecisionWithLogits,
    QuantileLoss,
    KLDIVLoss,
    BCELoss,
)
from deeprvat.utils import suggest_hparams

@@ -56,6 +59,9 @@
    "PearsonCorrTorch": PearsonCorrTorch,
    "BCEWithLogits": nn.BCEWithLogitsLoss,
    "AveragePrecisionWithLogits": AveragePrecisionWithLogits,
    "QuantileLoss": QuantileLoss,
    "KLDiv": KLDIVLoss,
    "BCELoss": BCELoss,
}
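As a hedged sketch (the config key below is an assumption, not the project's actual schema), the METRICS mapping is what lets a loss be selected by name from a training configuration: every entry, whether a torch.nn loss or one of the new custom classes, is instantiated without arguments and then called on predictions and targets.

loss_name = "QuantileLoss"       # e.g. read from the training config
loss_fn = METRICS[loss_name]()   # instantiate the metric / loss object
# loss = loss_fn(predictions, targets)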
OPTIMIZERS = {
"sgd": optim.SGD,
@@ -282,7 +288,7 @@ def __getitem__(self, index):
        start_idx = index * self.batch_size
        end_idx = min(self.total_samples, start_idx + self.batch_size)
        batch_samples = self.sample_order.iloc[start_idx:end_idx]
        samples_by_pheno = batch_samples.groupby("phenotype")
        samples_by_pheno = batch_samples.groupby("phenotype", observed=True)

        result = dict()
        for pheno, df in samples_by_pheno:
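For context on the observed=True change above: when the grouping column is a pandas Categorical, the default behaviour can also emit groups for categories that do not occur in the slice, whereas observed=True iterates only over phenotypes that are actually present. A small standalone illustration with toy data (not from the repository):

import pandas as pd

df = pd.DataFrame({
    "phenotype": pd.Categorical(["A", "A", "B"], categories=["A", "B", "C"]),
    "sample": [0, 1, 2],
})
# Only the phenotypes present in this slice form groups; the unused category "C" is skipped.
for pheno, group in df.groupby("phenotype", observed=True):
    print(pheno, len(group))   # A 2, B 1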
@@ -364,16 +370,14 @@ def __init__(
            num_variants = pheno_data["input_tensor_zarr"].shape[-1]
            if self.max_n_variants < num_variants: self.max_n_variants = num_variants

        # TODO: Rewrite this for multiphenotype data
        self.upsampling_factor = upsampling_factor
        if self.upsampling_factor > 1:
            raise NotImplementedError("Upsampling is not yet implemented")

            logger.info(
                f"Upsampling data with original sample number: {self.y.shape[0]}"
                f"Upsampling data with original sample number: {pheno_data['y'].shape[0]}"
            )
            samples = self.upsample()
            n_samples = self.samples.shape[0]
            upsampled_indices = self.upsample(pheno_data['y'])
            samples = np.append(np.arange(n_samples), upsampled_indices)
            n_samples += upsampled_indices.shape[0]
            logger.info(f"New sample number: {n_samples}")
        else:
            samples = np.arange(n_samples)
@@ -404,18 +408,18 @@ def __init__(
            "cache_tensors",
        )

    def upsample(self) -> np.ndarray:
        unique_values = self.y.unique()
    def upsample(self, y_data) -> np.ndarray:
        unique_values = y_data.unique()
        if unique_values.size() != torch.Size([2]):
            raise ValueError(
                "Upsampling is only supported for binary y, "
                f"but y has unique values {unique_values}")

        class_indices = [(self.y == v).nonzero(as_tuple=True)[0] for v in unique_values]
        class_indices = [(y_data == v).nonzero(as_tuple=True)[0] for v in unique_values]
        class_sizes = [idx.shape[0] for idx in class_indices]
        minority_class = 0 if class_sizes[0] < class_sizes[1] else 1
        minority_indices = class_indices[minority_class].detach().numpy()
        rng = np.random.default_rng()
        rng = np.random.default_rng(seed=42)
        upsampled_indices = rng.choice(
            minority_indices,
            size=(self.upsampling_factor - 1) * class_sizes[minority_class],
@@ -424,7 +428,7 @@ def upsample(self) -> np.ndarray:
        logger.info(f"Minority class size: {class_sizes[minority_class]}")
        logger.info(f"Increasing minority class size by {upsampled_indices.shape[0]}")

        self.samples = upsampled_indices
        return upsampled_indices
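The reworked upsample method above now builds extra indices for the minority class and returns them rather than mutating state. A self-contained sketch of the same idea on dummy labels (illustrative values only, not code from this repository):

import numpy as np
import torch

y = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1])   # imbalanced binary labels
upsampling_factor = 3
minority_indices = (y == 1).nonzero(as_tuple=True)[0].numpy()
rng = np.random.default_rng(seed=42)
upsampled = rng.choice(minority_indices, size=(upsampling_factor - 1) * len(minority_indices), replace=True)
samples = np.append(np.arange(len(y)), upsampled)   # original indices plus duplicated minority indices
print(samples)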

    def train_dataloader(self):
        logger.info(
@@ -807,6 +811,9 @@ def best_training_run(

    trials = study.trials_dataframe().query('state == "COMPLETE"')
    best_trial = trials.sort_values("value", ascending=False).iloc[0]
    if best_trial['value'] == float("inf"):
        best_trial = trials.sort_values("value", ascending=False).iloc[1]

    best_trial_id = best_trial["user_attrs_user_id"]

    logger.info(f"Best trial:\n{best_trial}")
33 changes: 32 additions & 1 deletion deeprvat/metrics.py
@@ -1,7 +1,7 @@
import logging
import sys

import torch
from torch import nn
import torch.nn.functional as F
from scipy.stats.stats import pearsonr
from sklearn.metrics import average_precision_score
@@ -92,3 +92,34 @@ def __init__(self):
    def __call__(self, logits, y):
        y_scores = F.sigmoid(logits.detach())
        return average_precision_score(y.detach().cpu().numpy(), y_scores.cpu().numpy())


class QuantileLoss:
    def __init__(self):
        pass

    def __call__(self, preds, y):
        # Pinball (quantile) loss for a fixed quantile q = 0.01.
        q = 0.01
        e = y - preds
        return torch.mean(torch.max(q * e, (q - 1) * e))

class KLDIVLoss:
    def __init__(self):
        pass

    def __call__(self, preds, targets):
        kl_loss = nn.KLDivLoss(reduction="batchmean")
        preds = F.log_softmax(preds, dim=0)  # KLDivLoss requires predictions as log probabilities
        targets = F.softmax(targets, dim=0)  # and targets as probabilities
        alpha = 1
        output = alpha * kl_loss(preds, targets)
        return output

class BCELoss:
    def __init__(self):
        pass

    def __call__(self, preds, targets):
        # Binary cross-entropy on raw logits (the sigmoid is applied internally).
        bceloss = nn.BCEWithLogitsLoss()
        loss = bceloss(preds, targets)
        return loss
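A minimal usage sketch for the three new metric classes, with the classes above in scope; the tensors and shapes are illustrative only, and in practice the losses are looked up by name via the METRICS dictionaries in models.py and train.py.

import torch

preds = torch.randn(8)                         # raw model outputs / logits
targets = torch.randint(0, 2, (8,)).float()    # binary labels

quantile = QuantileLoss()(preds, targets)      # pinball loss at q = 0.01
kldiv = KLDIVLoss()(preds, targets)            # KL divergence after (log-)softmax over the batch
bce = BCELoss()(preds, targets)                # BCEWithLogitsLoss under the hood
print(quantile.item(), kldiv.item(), bce.item())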
