added documentation
Munzlinger committed Nov 28, 2023
1 parent 07d111d commit 01d8d55
Showing 2 changed files with 235 additions and 10 deletions.
129 changes: 122 additions & 7 deletions deeprvat/deeprvat/models.py
@@ -41,8 +41,13 @@ def get_hparam(module: pl.LightningModule, param: str, default: Any):
else:
return default


class BaseModel(pl.LightningModule):
"""
Base class containing functions that will be called by PyTorch Lightning in the
background by default.
"""


def __init__(
self,
config: dict,
@@ -53,6 +58,14 @@ def __init__(
stage: str = "train",
**kwargs,
):
"""
config: dict, representing the content of config.yaml
n_annotations: dict, containing the number of annotations used for each phenotype
n_covariates: dict, containing the number of covariates used for each phenotype
n_genes: dict, containing the number of genes used for each phenotype
phenotypes: list, containing the phenotypes used during training
stage: str, prefix indicating which dataset the model is operating on (e.g. "train" or "val")
"""
super().__init__()
self.save_hyperparameters(config)
self.save_hyperparameters(kwargs)
@@ -74,7 +87,12 @@ def __init__(
else:
raise ValueError("Unknown objective_mode configuration parameter")


def configure_optimizers(self) -> torch.optim.Optimizer:
"""
Set up the optimizer and, if configured, a learning-rate scheduler from the
parameters specified in the config
"""
optimizer_config = self.hparams["optimizer"]
optimizer_class = getattr(torch.optim, optimizer_config["type"])
optimizer = optimizer_class(
@@ -99,10 +117,18 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
else:
return optimizer
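# Illustrative sketch of how the optimizer is resolved from the config above.
# Only the "type" key is visible in this diff; the "config" key and the values
# shown here are assumptions used purely for illustration.
import torch

_optimizer_config = {"type": "AdamW", "config": {"lr": 1e-3}}        # hypothetical config snippet
_optimizer_class = getattr(torch.optim, _optimizer_config["type"])   # -> torch.optim.AdamW
_optimizer = _optimizer_class(
    torch.nn.Linear(4, 1).parameters(), **_optimizer_config["config"]
)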


def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
"""
Called by the trainer during training; expected to return
a loss from which we can compute backward passes.
"""
# calls DeepSet.forward()
y_pred_by_pheno = self(batch)
results = dict()
# for all metrics we want to evaluate (specified in config)
for name, fn in self.metric_fns.items():
# compute the metric between ground truth and predicted scores, averaged across phenotypes
results[name] = torch.mean(
torch.stack(
[
@@ -112,22 +138,31 @@ def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
)
)
self.log(f"{self.hparams.stage}_{name}", results[name])

# set loss from which we compute backward passes
loss = results[self.hparams.metrics["loss"]]
if torch.any(torch.isnan(loss)):
raise RuntimeError("NaNs found in training loss")
return loss
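# Toy sketch of the loss computation in training_step: each metric is evaluated per
# phenotype, the per-phenotype values are stacked, and their mean becomes the loss.
# The phenotype names, values, and the MSE metric are assumptions for illustration.
import torch

_pred = {"pheno_a": torch.tensor([0.1, 0.4]), "pheno_b": torch.tensor([0.3, 0.2])}
_true = {"pheno_a": torch.tensor([0.0, 0.5]), "pheno_b": torch.tensor([0.3, 0.0])}
_mse = torch.nn.MSELoss()
_loss = torch.mean(torch.stack([_mse(_pred[p], _true[p]) for p in _pred]))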

def validation_step(self, batch: dict, batch_idx: int):
"""
During validation we do not compute backward passes; instead we accumulate
phenotype predictions and evaluate them afterwards as a whole.
"""
y_by_pheno = {pheno: pheno_batch["y"] for pheno, pheno_batch in batch.items()}
return {"y_pred_by_pheno": self(batch), "y_by_pheno": y_by_pheno}


def validation_epoch_end(
self, prediction_y: List[Dict[str, Dict[str, torch.Tensor]]]
):
"""
Evaluate all phenotype predictions in one go after accumulating them earlier
"""
y_pred_by_pheno = dict()
y_by_pheno = dict()
for result in prediction_y:
# create a dict for each phenotype that includes all respective predictions
pred = result["y_pred_by_pheno"]
for pheno, ys in pred.items():
y_pred_by_pheno[pheno] = torch.cat(
@@ -138,14 +173,16 @@ def validation_epoch_end(
ys,
]
)

# create a dict for each phenotype that includes the respective ground truth
target = result["y_by_pheno"]
for pheno, ys in target.items():
y_by_pheno[pheno] = torch.cat(
[y_by_pheno.get(pheno, torch.tensor([], device=self.device)), ys]
)

# store the value of each metric, averaged over phenotypes, in a dict keyed by metric name
results = dict()
# for all metrics we want to evaluate (specified in config)
for name, fn in self.metric_fns.items():
results[name] = torch.mean(
torch.stack(
@@ -156,15 +193,25 @@ def validation_epoch_end(
)
)
self.log(f"val_{name}", results[name])

# Keep only the best (min/max) objective value seen so far in self.best_objective
# to determine whether progress was made in the last training epoch.
self.best_objective = self.objective_operation(
self.best_objective, results[self.hparams.metrics["objective"]].item()
)
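# Minimal sketch of the accumulation pattern used above: per-batch prediction chunks
# are concatenated per phenotype, starting from an empty tensor when a phenotype is
# seen for the first time. Phenotype name and values are illustrative assumptions.
import torch

_acc = dict()
for _chunk in [torch.tensor([1.0, 2.0]), torch.tensor([3.0])]:
    _acc["pheno_a"] = torch.cat([_acc.get("pheno_a", torch.tensor([])), _chunk])
# _acc["pheno_a"] == tensor([1., 2., 3.])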


def test_step(self, batch: dict, batch_idx: int):
"""
During testing we do not compute backward passes; instead we accumulate
phenotype predictions and evaluate them afterwards as a whole.
"""
return {"y_pred": self(batch), "y": batch["y"]}


def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]):
"""
Evaluate all phenotype predictions in one go after accumulating them earlier
"""
y_pred = torch.cat([p["y_pred"] for p in prediction_y])
y = torch.cat([p["y"] for p in prediction_y])

@@ -180,8 +227,14 @@ def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]):
def configure_callbacks(self):
return [ModelSummary()]


class DeepSetAgg(pl.LightningModule):
"""
Class containing the gene impairment module used for burden computation.
Variants are fed through an embedding network Phi to compute a variant embedding.
The variant embedding is processed by a permutation-invariant aggregation to yield a gene embedding.
Afterwards, a second network Rho estimates the final gene impairment score.
All parameters of the gene impairment module are shared across genes and traits.
"""
def __init__(
self,
n_annotations: int,
@@ -196,6 +249,21 @@ def __init__(
use_sigmoid: bool = False,
reverse: bool = False,
):
"""
n_annotations: int, number of annotations
phi_layers: int, number of layers in Phi
phi_hidden_dim: int, internal dimensionality of linear layers in Phi
rho_layers: int, number of layers in Rho
rho_hidden_dim: int, internal dimensionality of linear layers in Rho
activation: str, activation function used, has to match its name in torch.nn
pool: str, invariant aggregation function used to aggregate gene variants
possibilities: max, sum
output_dim: int, number of burden scores
dropout: float, probability with which activations are set to 0 during training
use_sigmoid: bool, whether to project burden scores to [0, 1]. Also used as a
linear activation function to mimic association testing during training
reverse: bool, whether to reverse the burden score (used during association testing)
"""
super().__init__()

self.output_dim = output_dim
@@ -205,6 +273,7 @@ def __init__(
self.use_sigmoid = use_sigmoid
self.reverse = reverse

# setup of Phi
input_dim = n_annotations
phi = []
for l in range(phi_layers):
@@ -216,10 +285,12 @@ def __init__(
input_dim = output_dim
self.phi = nn.Sequential(OrderedDict(phi))

# setup permutation-invariant aggregation function
if pool not in ("sum", "max"):
raise ValueError(f"Unknown pooling operation {pool}")
self.pool = pool

# setup of Rho
rho = []
for l in range(rho_layers - 1):
output_dim = rho_hidden_dim
@@ -231,8 +302,12 @@ def __init__(
rho.append(
(f"rho_linear_{rho_layers - 1}", nn.Linear(input_dim, self.output_dim))
)
# No final non-linear activation function to keep the relationship between
# gene impairment scores and phenotypes linear
self.rho = nn.Sequential(OrderedDict(rho))
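# Minimal sketch of building a named nn.Sequential from an OrderedDict, as done for
# Phi and Rho above; the layer sizes, names, and activation here are illustrative assumptions.
from collections import OrderedDict
import torch.nn as nn

_rho_sketch = nn.Sequential(OrderedDict([
    ("rho_linear_0", nn.Linear(20, 16)),
    ("rho_activation_0", nn.LeakyReLU()),
    ("rho_linear_1", nn.Linear(16, 1)),  # no final non-linear activation, as noted above
]))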

# Reverse the burden score during association testing if the model predicts in negative space.
# See associate.py, reverse_models() for further detail.
def set_reverse(self, reverse: bool = True):
self.reverse = reverse

@@ -252,8 +327,13 @@ def forward(self, x):
x = torch.sigmoid(x)
return x
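# Small sketch of the permutation-invariant pooling over variants: summing (or taking
# the max) across the variant dimension yields the same gene representation regardless
# of variant order. The tensor layout and sizes here are illustrative assumptions.
import torch

_emb = torch.rand(2, 3, 5, 4)              # samples x genes x variants x embedding (after Phi)
_pooled = torch.sum(_emb, dim=2)           # -> samples x genes x embedding
_shuffled = _emb[:, :, torch.randperm(5), :]
assert torch.allclose(_pooled, torch.sum(_shuffled, dim=2))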


class DeepSet(BaseModel):
"""
Wrapper class for burden computation that also performs phenotype prediction.
It inherits from BaseModel, which is where the PyTorch Lightning-specific functions
like "training_step" or "validation_epoch_end" can be found.
Those functions are called in the background by default.
"""
def __init__(
self,
config: dict,
@@ -266,6 +346,18 @@ def __init__(
reverse: bool = False,
**kwargs,
):
"""
config: dict, representing the content of config.yaml
n_annotations: dict, containing the number of annotations used for each phenotype
n_covariates: dict, containing the number of covariates used for each phenotype
n_genes: dict, containing the number of genes used for each phenotype
phenotypes: list, containing the phenotypes used during training
agg_model: pl.LightningModule / nn.Module, model used for burden computation;
if the module isn't given, it will be initialized
use_sigmoid: bool, whether to project burden scores to [0, 1]. Also used as a
linear activation function to mimic association testing during training
reverse: bool, whether to reverse the burden score (used during association testing)
"""
super().__init__(
config, n_annotations, n_covariates, n_genes, phenotypes, **kwargs
)
@@ -294,6 +386,13 @@ def __init__(
)
self.agg_model.train(False if self.hparams.stage == "val" else True)

# dict of linear layers (one per phenotype) used for phenotype prediction.
# Their outputs can be compared against ground truth data.
# self.agg_model compresses a batch
# from: samples x genes x annotations x variants;
# to: samples x genes
# afterwards genes are concatenated with covariates
# to: samples x (genes + covariates)
self.gene_pheno = nn.ModuleDict(
{
pheno: nn.Linear(
@@ -302,8 +401,23 @@ def __init__(
for pheno in self.hparams.phenotypes
}
)

def forward(self, batch):
"""
batch: dict of phenotypes, each containing the following keys:
- indices: tensor,
indices for the underlying dataframe
- covariates: tensor,
covariates of samples, e.g. age,
content: samples x covariates
- rare_variant_annotations: tensor,
actual annotated genomic variants,
content: samples x genes x annotations x variants
- y: tensor,
actual phenotypes (ground truth data)
"""
result = dict()
for pheno, this_batch in batch.items():
x = this_batch["rare_variant_annotations"]
@@ -380,3 +494,4 @@ def forward(self, batch):
x = torch.cat((batch["covariates"], x), dim=1)
x = self.gene_pheno(x).squeeze(dim=1) # samples
return x
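# Illustrative sketch of the batch layout described in forward() for a single phenotype.
# The dimensions (16 samples, 100 genes, 34 annotations, 50 variants) and the phenotype
# name are assumptions chosen only to make the expected shapes concrete.
import torch

_toy_batch = {
    "pheno_a": {
        "indices": torch.arange(16),                              # indices into the underlying dataframe
        "covariates": torch.rand(16, 2),                          # samples x covariates (e.g. age)
        "rare_variant_annotations": torch.rand(16, 100, 34, 50),  # samples x genes x annotations x variants
        "y": torch.rand(16),                                      # ground-truth phenotype values
    }
}
# The gene impairment module compresses rare_variant_annotations to samples x genes;
# the genes are then concatenated with the covariates to samples x (genes + covariates).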
