
Commit

added documentation and docstrings
Munzlinger committed Nov 29, 2023
1 parent 5efbeeb commit 89abc6b
Showing 2 changed files with 270 additions and 94 deletions.
230 changes: 173 additions & 57 deletions deeprvat/deeprvat/models.py
@@ -59,12 +59,16 @@ def __init__(
**kwargs,
):
"""
Initializes BaseModel.
Args:
- config (dict): Represents the content of config.yaml.
- n_annotations (Dict[str, int]): Contains the number of annotations used for each phenotype.
- n_covariates (Dict[str, int]): Contains the number of covariates used for each phenotype.
- n_genes (Dict[str, int]): Contains the number of genes used for each phenotype.
- phenotypes (List[str]): Contains the phenotypes used during training.
- stage (str, optional): Contains a prefix indicating the dataset the model is operating on. Defaults to "train".
- **kwargs: Additional keyword arguments.
"""
super().__init__()
self.save_hyperparameters(config)
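# save_hyperparameters exposes the config entries via self.hparams
# (e.g. self.hparams.stage, used by subclasses below)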
@@ -120,8 +124,19 @@ def configure_optimizers(self) -> torch.optim.Optimizer:

def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
"""
Function called by the trainer during training; returns the loss used
to update weights and biases.
Args:
- batch (dict): A dictionary containing the batch data.
- batch_idx (int): The index of the current batch.
Returns:
- torch.Tensor: The loss value computed to update weights and biases
based on the predictions.
Raises:
- RuntimeError: If NaNs are found in the training loss.
"""
# calls DeepSet.forward()
y_pred_by_pheno = self(batch)
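# y_pred_by_pheno maps each phenotype name to the prediction tensor for this batch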
@@ -148,17 +163,29 @@ def validation_step(self, batch: dict, batch_idx: int):
"""
During validation we do not compute backward passes, so that we can accumulate
phenotype predictions and evaluate them afterwards as a whole.
Args:
- batch (dict): A dictionary containing the validation batch data.
- batch_idx (int): The index of the current validation batch.
Returns:
- dict: A dictionary containing phenotype predictions ("y_pred_by_pheno")
and corresponding ground truth values ("y_by_pheno").
"""
y_by_pheno = {pheno: pheno_batch["y"] for pheno, pheno_batch in batch.items()}
return {"y_pred_by_pheno": self(batch), "y_by_pheno": y_by_pheno}


def validation_epoch_end(
self, prediction_y: List[Dict[str, Dict[str, torch.Tensor]]]
):
"""
Evaluate accumulated phenotype predictions at the end of the validation epoch.
Args:
- prediction_y (List[Dict[str, Dict[str, torch.Tensor]]]): A list of dictionaries containing accumulated phenotype predictions
and corresponding ground truth values obtained during the validation process.
"""
y_pred_by_pheno = dict()
y_by_pheno = dict()
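# merge the per-batch dictionaries so each phenotype maps to one tensor of
# predictions and one tensor of ground truth values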
for result in prediction_y:
@@ -204,13 +231,25 @@ def test_step(self, batch: dict, batch_idx: int):
"""
During testing we do not compute backward passes, so that we can accumulate
phenotype predictions and evaluate them afterwards as a whole.
Args:
- batch (dict): A dictionary containing the test batch data.
- batch_idx (int): The index of the current test batch.
Returns:
- dict: A dictionary containing phenotype predictions ("y_pred")
and corresponding ground truth values ("y").
"""
return {"y_pred": self(batch), "y": batch["y"]}


def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]):
"""
Evaluate accumulated phenotype predictions at the end of the testing epoch.
Args:
- prediction_y (List[Dict[str, torch.Tensor]]): A list of dictionaries containing accumulated phenotype predictions
and corresponding ground truth values obtained during the testing process.
"""
y_pred = torch.cat([p["y_pred"] for p in prediction_y])
y = torch.cat([p["y"] for p in prediction_y])
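# all test batches are concatenated so the evaluation below runs over the full test set at once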
@@ -250,19 +289,20 @@ def __init__(
reverse: bool = False,
):
"""
Initializes the DeepSetAgg module.
Args:
- n_annotations (int): Number of annotations.
- phi_layers (int): Number of layers in Phi.
- phi_hidden_dim (int): Internal dimensionality of linear layers in Phi.
- rho_layers (int): Number of layers in Rho.
- rho_hidden_dim (int): Internal dimensionality of linear layers in Rho.
- activation (str): Activation function used; should match its name in torch.nn.
- pool (str): Invariant aggregation function used to aggregate gene variants. Possible values: 'max', 'sum'.
- output_dim (int, optional): Number of burden scores. Defaults to 1.
- dropout (Optional[float], optional): Probability by which some parameters are set to 0.
- use_sigmoid (bool, optional): Whether to project burden scores to [0, 1]; also used as a linear activation function to mimic association testing during training. Defaults to False.
- reverse (bool, optional): Whether to reverse the burden score (used during association testing). Defaults to False.
"""
super().__init__()

@@ -306,12 +346,29 @@ def __init__(
# gene impairment scores and phenotypes linear
self.rho = nn.Sequential(OrderedDict(rho))


def set_reverse(self, reverse: bool = True):
"""
Reverse the burden score during association testing if the model predicts in negative space.
Args:
- reverse (bool, optional): Indicates whether the 'reverse' attribute should be set to True or False.
Defaults to True.
Note:
- See associate.py, reverse_models() for further detail.
"""
self.reverse = reverse
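# Illustrative usage sketch (hypothetical hyperparameter values, not taken from this commit):
#     agg = DeepSetAgg(n_annotations=33, phi_layers=2, phi_hidden_dim=20,
#                      rho_layers=3, rho_hidden_dim=10,
#                      activation="LeakyReLU", pool="max")
#     agg.set_reverse()  # subsequent forward() calls return negated burden scores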

def forward(self, x):
"""
Perform forward pass through the model.
Args:
- x (tensor): Batched input data
Returns:
- tensor: Burden scores
"""
x = self.phi(x.permute((0, 1, 3, 2)))
# x.shape = samples x genes x variants x phi_latent
if self.pool == "sum":
@@ -320,7 +377,7 @@ def forward(self, x):
x = torch.max(x, dim=2).values
# Now x.shape = samples x genes x phi_latent
x = self.rho(x)
# x.shape = samples x genes x 1
if self.reverse:
x = -x
if self.use_sigmoid:
@@ -346,17 +403,22 @@ def __init__(
reverse: bool = False,
**kwargs,
):

"""
Initialize the DeepSet model.
Args:
- config (dict): Represents the content of config.yaml.
- n_annotations (Dict[str, int]): Contains the number of annotations used for each phenotype.
- n_covariates (Dict[str, int]): Contains the number of covariates used for each phenotype.
- n_genes (Dict[str, int]): Contains the number of genes used for each phenotype.
- phenotypes (List[str]): Contains the phenotypes used during training.
- agg_model (Optional[pl.LightningModule / nn.Module]): Model used for burden computation. If not provided,
it will be initialized.
- use_sigmoid (bool): Determines if burden scores should be projected to [0, 1]. Acts as a linear activation
function to mimic association testing during training.
- reverse (bool): Determines if the burden score should be reversed (used during association testing).
- **kwargs: Additional keyword arguments.
"""
super().__init__(
config, n_annotations, n_covariates, n_genes, phenotypes, **kwargs
@@ -369,6 +431,9 @@ def __init__(
pool = get_hparam(self, "pool", "sum")
dropout = get_hparam(self, "dropout", None)

# self.agg_model compresses a batch
# from: samples x genes x annotations x variants
# to: samples x genes
if agg_model is not None:
self.agg_model = agg_model
else:
@@ -385,14 +450,12 @@ def __init__(
reverse=reverse,
)
self.agg_model.train(False if self.hparams.stage == "val" else True)

# dict of various linear layers used for phenotype prediction.
# Returns can be tested against ground truth data.
self.gene_pheno = nn.ModuleDict(
{
pheno: nn.Linear(
@@ -404,19 +467,18 @@ def __init__(

def forward(self, batch):
"""
Forward pass through the model.
Args:
- batch (dict): Dictionary of phenotypes, each containing the following keys:
- indices (tensor): Indices for the underlying dataframe.
- covariates (tensor): Covariates of samples, e.g., age. Content: samples x covariates.
- rare_variant_annotations (tensor): annotated genomic variants.
Content: samples x genes x annotations x variants.
- y (tensor): Actual phenotypes (ground truth data).
Returns:
- dict: Dictionary containing predicted phenotypes.
"""
result = dict()
for pheno, this_batch in batch.items():
@@ -432,7 +494,21 @@ def forward(self, batch):


class LinearAgg(pl.LightningModule):
"""
This model can be used to capture only linear effects, as it uses a single
linear layer without a non-linear activation function.
It still contains the gene impairment module used for burden computation.
"""

def __init__(self, n_annotations: int, pool: str, output_dim: int = 1):
"""
Initialize the LinearAgg model.
Args:
- n_annotations (int): Number of annotations.
- pool (str): Pooling method ("sum" or "max") to be used.
- output_dim (int, optional): Dimensionality of the output. Defaults to 1.
"""
super().__init__()

self.output_dim = output_dim
@@ -442,6 +518,15 @@ def __init__(self, n_annotations: int, pool: str, output_dim: int = 1):
self.linear = nn.Linear(n_annotations, self.output_dim)
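# in effect, this single linear layer plays the role of both Phi and Rho in DeepSetAgg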

def forward(self, x):
"""
Perform forward pass through the model.
Args:
- x (tensor): Batched input data
Returns:
- tensor: Burden scores
"""
x = self.linear(
x.permute((0, 1, 3, 2))
) # x.shape = samples x genes x variants x output_dim
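# the collapsed lines presumably pool over the variants dimension (sum or max,
# depending on `pool`), mirroring DeepSetAgg.forward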
@@ -454,6 +539,11 @@ def forward(self, x):


class TwoLayer(BaseModel):
"""
Wrapper class to capture linear effects. Inherits from BaseModel,
where PyTorch Lightning specific functions like "training_step" or
"validation_epoch_end" can be found. Those functions are called in the background by default.
"""
def __init__(
self,
config: dict,
@@ -463,6 +553,18 @@ def __init__(
agg_model: Optional[nn.Module] = None,
**kwargs,
):
"""
Initializes the TwoLayer model.
Args:
- config (dict): Represents the content of config.yaml.
- n_annotations (int): Number of annotations.
- n_covariates (int): Number of covariates.
- n_genes (int): Number of genes.
- agg_model (Optional[nn.Module]): Model used for burden computation. If not provided,
it will be initialized.
- **kwargs: Additional keyword arguments.
"""
super().__init__(config, n_annotations, n_covariates, n_genes, **kwargs)

logger.info("Initializing TwoLayer model with parameters:")
@@ -488,6 +590,20 @@ def __init__(
self.gene_pheno = nn.Linear(self.hparams.n_covariates + self.hparams.n_genes, 1)
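# a single linear head maps the concatenated covariates and gene burden scores
# to one phenotype prediction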

def forward(self, batch):
"""
Forward pass through the model.
Args:
- batch (dict): Dictionary containing the following keys:
- indices (tensor): Indices for the underlying dataframe.
- covariates (tensor): Covariates of samples, e.g., age. Content: samples x covariates.
- rare_variant_annotations (tensor): annotated genomic variants.
Content: samples x genes x annotations x variants.
- y (tensor): Actual phenotypes (ground truth data).
Returns:
- dict: Dictionary containing predicted phenotypes.
"""
# samples x genes x annotations x variants
x = batch["rare_variant_annotations"]
x = self.agg_model(x).squeeze(dim=2) # samples x genes
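# the remaining (collapsed) lines presumably concatenate the covariates and
# apply self.gene_pheno, analogous to DeepSet.forward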