Skip to content

Commit

Permalink
Update docstrings for functions
Browse files Browse the repository at this point in the history
  • Loading branch information
meyerkm committed Nov 27, 2023
1 parent f108ae9 commit eecf8bb
Show file tree
Hide file tree
Showing 4 changed files with 438 additions and 3 deletions.
210 changes: 208 additions & 2 deletions deeprvat/deeprvat/associate.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,22 @@ def get_burden(
device: torch.device = torch.device("cpu"),
skip_burdens=False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute burden scores for rare variants.
Parameters:
- batch (Dict): A dictionary containing batched data from the DataLoader.
- agg_models (Dict[str, List[nn.Module]]): Loaded PyTorch model(s) for each repeat used for burden computation.
Each key in the dictionary corresponds to a respective repeat.
- device (torch.device): Device to perform computations on (default is CPU).
- skip_burdens (bool): Flag to skip burden computation (default is False).
Notes:
- Checkpoint models all corresponding to the same repeat are averaged for that repeat.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing burden scores, target y phenotype values, and x phenotypes.
"""
with torch.no_grad():
X = batch["rare_variant_annotations"].to(device)

Check warning on line 71 in deeprvat/deeprvat/associate.py

View workflow job for this annotation

GitHub Actions / docs-link-check

Block quote ends without a blank line; unexpected unindent.

Check warning on line 71 in deeprvat/deeprvat/associate.py

View workflow job for this annotation

GitHub Actions / docs-build

Block quote ends without a blank line; unexpected unindent.
burden = []
Expand All @@ -74,6 +90,15 @@ def get_burden(


def separate_parallel_results(results: List) -> Tuple[List, ...]:
    """
    Split per-gene regression results into parallel lists.

    Each element of ``results`` is a fixed-length record (e.g. gene name,
    beta, p-value); the records are transposed so that every field becomes
    its own list.

    Parameters
    ----------
    results : List
        Records produced by running the regression on each gene.

    Returns
    -------
    Tuple[List, ...]
        One list per field, e.g. (regressed_genes, betas, pvals).
    """
    transposed = zip(*results)
    return tuple(list(column) for column in transposed)


Expand All @@ -88,6 +113,18 @@ def make_dataset_(
data_key="data",
samples: Optional[List[int]] = None,
) -> Dataset:
"""
Create a dataset based on the configuration.
Parameters:
- config (Dict): Configuration dictionary.
- debug (bool): Flag for debugging (default is False).
- data_key (str): Key for dataset configuration in the config dictionary (default is "data").
- samples (List[int]): List of sample indices to include in the dataset (default is None).
Returns:
Dataset: Loaded instance of the created dataset.
"""
data_config = config[data_key]

ds_pickled = data_config.get("pickled", None)
Expand Down Expand Up @@ -122,6 +159,18 @@ def make_dataset_(
@click.argument("config-file", type=click.Path(exists=True))
@click.argument("out-file", type=click.Path())
def make_dataset(debug: bool, data_key: str, config_file: str, out_file: str):
"""
Create a dataset based on the provided configuration and save to a pickle file.
Parameters:
- debug (bool): Flag for debugging.
- data_key (str): Key for dataset configuration in the config dictionary (default is "data").
- config_file (str): Path to the configuration file.
- out_file (str): Path to the output file.
Returns:
Created dataset saved to output.pkl
"""
with open(config_file) as f:
config = yaml.safe_load(f)

Expand All @@ -144,6 +193,29 @@ def compute_burdens_(
compression_level: int = 1,
skip_burdens: bool = False,
) -> Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array]:
"""
Compute burdens using the PyTorch model for each repeat.
Parameters:
- debug (bool): Flag for debugging.
- config (Dict): Configuration dictionary.
- ds (torch.utils.data.Dataset): Torch dataset.
- cache_dir (str): Directory to cache zarr files of computed burdens, x phenotypes, and y phenotypes.
- agg_models (Dict[str, List[nn.Module]]): Loaded PyTorch model(s) for each repeat used for burden computation.
Each key in the dictionary corresponds to a respective repeat.
- n_chunks (Optional[int]): Number of chunks to split data for processing (default is None).
- chunk (Optional[int]): Index of the chunk of data (default is None).
- device (torch.device): Device to perform computations on (default is CPU).
- bottleneck (bool): Flag to enable bottlenecking number of batches (default is False).
- compression_level (int): Blosc compressor compression level for zarr files (default is 1).
- skip_burdens (bool): Flag to skip burden computation (default is False).
Notes:
- Checkpoint models all corresponding to the same repeat are averaged for that repeat.
Returns:
Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array]: Tuple containing genes, burdens, target y phenotypes, and x phenotypes.
"""
if not skip_burdens:
logger.info("agg_models[*][*].reverse:")
pprint(
Expand Down Expand Up @@ -272,6 +344,17 @@ def load_one_model(
checkpoint: str,
device: torch.device = torch.device("cpu"),
):
"""
Load a single burden score computation model from a checkpoint file.
Parameters:
- config (Dict): Configuration dictionary.
- checkpoint (str): Path to the model checkpoint file.
- device (torch.device): Device to load the model onto (default is CPU).
Returns:
nn.Module: Loaded PyTorch model for burden score computation.
"""
model_class = getattr(deeprvat_models, config["model"]["type"])
model = model_class.load_from_checkpoint(
checkpoint,
Expand All @@ -290,6 +373,17 @@ def load_one_model(
def reverse_models(
model_config_file: str, data_config_file: str, checkpoint_files: Tuple[str]
):
"""
Determine if the burden score computation PyTorch model should reverse the output based on PLOF annotations.
Parameters:
- model_config_file (str): Path to the model configuration file.
- data_config_file (str): Path to the data configuration file.
- checkpoint_files (Tuple[str]): Paths to checkpoint files.
Returns:
checkpoint.reverse file is created if the model should reverse the burden score output.
"""
with open(model_config_file) as f:
model_config = yaml.safe_load(f)

Expand Down Expand Up @@ -351,6 +445,23 @@ def load_models(
checkpoint_files: Tuple[str],
device: torch.device = torch.device("cpu"),
) -> Dict[str, List[nn.Module]]:
"""
Load models from multiple checkpoints for multiple repeats.
Parameters:
- config (Dict): Configuration dictionary.
- checkpoint_files (Tuple[str]): Paths to checkpoint files.
- device (torch.device): Device to load the models onto (default is CPU).
Returns:
Dict[str, List[nn.Module]]: Dictionary of loaded PyTorch models for burden score computation for each repeat.
Examples:
>>> config = {"model": {"type": "MyModel", "config": {"param": "value"}}}
>>> checkpoint_files = ("checkpoint1.pth", "checkpoint2.pth")
>>> load_models(config, checkpoint_files)
{'repeat_0': [MyModel(), MyModel()]}
"""
logger.info("Loading models and checkpoints")

if all(
Expand Down Expand Up @@ -430,6 +541,28 @@ def compute_burdens(
checkpoint_files: Tuple[str],
out_dir: str,
):
"""
Compute burdens based on model and dataset provided.
Parameters:
- debug (bool): Flag for debugging.
- bottleneck (bool): Flag to enable bottlenecking number of batches.
- n_chunks (Optional[int]): Number of chunks to split data for processing (default is None).
- chunk (Optional[int]): Index of the chunk of data (default is None).
- dataset_file (Optional[str]): Path to the dataset file, i.e. association_dataset.pkl.
- link_burdens (Optional[str]): Path to burden.zarr file to link.
- data_config_file (str): Path to the data configuration file.
- model_config_file (str): Path to the model configuration file.
- checkpoint_files (Tuple[str]): Paths to model checkpoint files.
- out_dir (str): Path to the output directory.
Returns:
Computed burdens, corresponding genes, and targets are saved in the out_dir.
np.ndarray: Corresponding genes, saved as genes.npy
zarr.core.Array: Computed burdens, saved as burdens.zarr
zarr.core.Array: Target y phenotype, saved as y.zarr
zarr.core.Array: X phenotype, saved as x.zarr
"""
if len(checkpoint_files) == 0:
raise ValueError("At least one checkpoint file must be supplied")

Expand Down Expand Up @@ -479,7 +612,19 @@ def compute_burdens(
source_path.symlink_to(link_burdens)


def regress_on_gene_scoretest(gene: str, burdens: np.ndarray, model_score):
def regress_on_gene_scoretest(gene: str, burdens: np.ndarray, model_score,
) -> Tuple[List[str], List[float], List[float]]:
"""
Perform regression on a gene using the score test.
Parameters:
- gene (str): Gene name.
- burdens (np.ndarray): Burden scores associated with the gene.
- model_score: Model for score test.
Returns:
Tuple[List[str], List[float], List[float]]: Tuple containing gene name, beta, and p-value.
"""
burdens = burdens.reshape(burdens.shape[0], -1)
logger.info(f"Burdens shape: {burdens.shape}")

Expand Down Expand Up @@ -517,7 +662,21 @@ def regress_on_gene(
x_pheno: np.ndarray,
use_bias: bool,
use_x_pheno: bool,
) -> Optional[Tuple[List[str], List[float], List[float]]]:
) -> Tuple[List[str], List[float], List[float]]:
"""
Perform regression on a gene using OLS.
Parameters:
- gene (str): Gene name.
- X (np.ndarray): Burden score data.
- y (np.ndarray): Y phenotype data.
- x_pheno (np.ndarray): X phenotype data.
- use_bias (bool): Flag to include bias term.
- use_x_pheno (bool): Flag to include x phenotype data in regression.
Returns:
Tuple[List[str], List[float], List[float]]: Tuple containing gene name, beta, and p-value.
"""
X = X.reshape(X.shape[0], -1)
if np.all(np.abs(X) < 1e-6):
logger.warning(f"Burden for gene {gene} is 0 for all samples; skipping")
Expand Down Expand Up @@ -561,6 +720,23 @@ def regress_(
use_x_pheno: bool = True,
do_scoretest: bool = True,
) -> pd.DataFrame:
"""
Perform regression on multiple genes.
Parameters:
- config (Dict): Configuration dictionary.
- use_bias (bool): Flag to include bias term when performing OLS regression.
- burdens (np.ndarray): Burden score data.
- y (np.ndarray): Y phenotype data.
- gene_indices (np.ndarray): Indices of genes.
- genes (pd.Series): Gene names.
- x_pheno (np.ndarray): X phenotype data.
- use_x_pheno (bool): Flag to include x phenotype data when performing OLS regression (default is True).
- do_scoretest (bool): Flag to use the scoretest from SEAK (default is True).
Returns:
pd.DataFrame: DataFrame containing regression results on all genes.
"""
assert len(gene_indices) == len(genes)

logger.info(f"Computing associations")
Expand Down Expand Up @@ -640,6 +816,25 @@ def regress(
do_scoretest: bool,
sample_file: Optional[str],
):
"""
Perform regression analysis.
Parameters:
- debug (bool): Flag for debugging.
- chunk (int): Index of the chunk of data (default is 0).
- n_chunks (int): Number of chunks to split data for processing (default is 1).
- use_bias (bool): Flag to include bias term when performing OLS regression.
- gene_file (str): Path to the gene file.
- repeat (int): Index of the repeat (default is 0).
- config_file (str): Path to the configuration file.
- burden_dir (str): Path to the directory containing burdens.zarr file.
- out_dir (str): Path to the output directory.
- do_scoretest (bool): Flag to use the scoretest from SEAK.
- sample_file (Optional[str]): Path to the sample file.
Returns:
Regression results saved to out_dir as "burden_associations_{chunk}.parquet"
"""
logger.info("Loading saved burdens")
y = zarr.open(Path(burden_dir) / "y.zarr")[:]
burdens = zarr.open(Path(burden_dir) / "burdens.zarr")[:, :, repeat]
Expand Down Expand Up @@ -708,6 +903,17 @@ def regress(
def combine_regression_results(
result_files: Tuple[str], out_file: str, model_name: Optional[str]
):
"""
Combine multiple regression result files.
Parameters:
- model_name (str): Name of the regression model.
- result_files (List[str]): List of paths to regression result files.
- out_dir (str): Path to the output directory.
Returns:
Concatenated regression results saved to a parquet file.
"""
logger.info(f"Concatenating results")
results = pd.concat([pd.read_parquet(f, engine="pyarrow") for f in result_files])

Expand Down
20 changes: 20 additions & 0 deletions deeprvat/deeprvat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,26 @@ def update_config(
seed_genes_out: Optional[str],
new_config_file: str,
):
"""
Select seed genes based on baseline results and update configuration file.
Parameters:
- old_config_file (str): Path to the old configuration file.
- phenotype (Optional[str]): Phenotype to update in the configuration.
- seed_gene_dir (Optional[str]): Directory containing seed genes.
- baseline_results (Tuple[str]): Paths to baseline result files.
- baseline_results_out (Optional[str]): Path to save the updated baseline results.
- seed_genes_out (Optional[str]): Path to save the seed genes.
- new_config_file (str): Path to the new configuration file.
Raises:
- ValueError: If neither --seed-gene-dir nor --baseline-results is specified.
Returns:
Updated configuration file saved to new_config.yaml.
Selected seed genes saved to seed_genes_out.parquet.
Optionally, save baseline results to parquet file if baseline_results_out is specified.
"""
if seed_gene_dir is None and len(baseline_results) == 0:
raise ValueError(
"One of --seed-gene-dir and --baseline-results " "must be specified"
Expand Down
Loading

0 comments on commit eecf8bb

Please sign in to comment.