diff --git a/deeprvat/deeprvat/associate.py b/deeprvat/deeprvat/associate.py
index a9d29228..0cfa8407 100644
--- a/deeprvat/deeprvat/associate.py
+++ b/deeprvat/deeprvat/associate.py
@@ -51,6 +51,22 @@ def get_burden(
     device: torch.device = torch.device("cpu"),
     skip_burdens=False,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Compute burden scores for rare variants.
+
+    Parameters:
+    - batch (Dict): A dictionary containing batched data from the DataLoader.
+    - agg_models (Dict[str, List[nn.Module]]): Loaded PyTorch model(s) for each repeat used for burden computation.
+        Each key in the dictionary corresponds to a respective repeat.
+    - device (torch.device): Device to perform computations on (default is CPU).
+    - skip_burdens (bool): Flag to skip burden computation (default is False).
+
+    Notes:
+    - All checkpoint models corresponding to the same repeat are averaged for that repeat.
+
+    Returns:
+    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing burden scores, target y phenotype values, and x phenotypes.
+    """
     with torch.no_grad():
         X = batch["rare_variant_annotations"].to(device)
         burden = []
@@ -74,6 +90,15 @@ def get_burden(


 def separate_parallel_results(results: List) -> Tuple[List, ...]:
+    """
+    Separate results from running regression on each gene.
+
+    Parameters:
+    - results (List): List of results obtained from regression analysis.
+
+    Returns:
+    Tuple[List, ...]: Tuple of lists containing separated results of regressed_genes, betas, and pvals.
+    """
     return tuple(map(list, zip(*results)))


@@ -88,6 +113,18 @@ def make_dataset_(
     data_key="data",
     samples: Optional[List[int]] = None,
 ) -> Dataset:
+    """
+    Create a dataset based on the configuration.
+
+    Parameters:
+    - config (Dict): Configuration dictionary.
+    - debug (bool): Flag for debugging (default is False).
+    - data_key (str): Key for dataset configuration in the config dictionary (default is "data").
+    - samples (List[int]): List of sample indices to include in the dataset (default is None).
+
+    Returns:
+    Dataset: Loaded instance of the created dataset.
+    """
     data_config = config[data_key]

     ds_pickled = data_config.get("pickled", None)
@@ -122,6 +159,18 @@ def make_dataset_(
 @click.argument("config-file", type=click.Path(exists=True))
 @click.argument("out-file", type=click.Path())
 def make_dataset(debug: bool, data_key: str, config_file: str, out_file: str):
+    """
+    Create a dataset based on the provided configuration and save it to a pickle file.
+
+    Parameters:
+    - debug (bool): Flag for debugging.
+    - data_key (str): Key for dataset configuration in the config dictionary (default is "data").
+    - config_file (str): Path to the configuration file.
+    - out_file (str): Path to the output file.
+
+    Returns:
+    Created dataset saved to out_file as a pickle file.
+    """
     with open(config_file) as f:
         config = yaml.safe_load(f)

@@ -144,6 +193,29 @@ def compute_burdens_(
     compression_level: int = 1,
     skip_burdens: bool = False,
 ) -> Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array]:
+    """
+    Compute burdens using the PyTorch model for each repeat.
+
+    Parameters:
+    - debug (bool): Flag for debugging.
+    - config (Dict): Configuration dictionary.
+    - ds (torch.utils.data.Dataset): Torch dataset.
+    - cache_dir (str): Directory to cache zarr files of computed burdens, x phenotypes, and y phenotypes.
+    - agg_models (Dict[str, List[nn.Module]]): Loaded PyTorch model(s) for each repeat used for burden computation.
+        Each key in the dictionary corresponds to a respective repeat.
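+        A hypothetical illustration of the expected structure: {"repeat_0": [model_a, model_b], "repeat_1": [model_c, model_d]}.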
+    - n_chunks (Optional[int]): Number of chunks to split data for processing (default is None).
+    - chunk (Optional[int]): Index of the chunk of data (default is None).
+    - device (torch.device): Device to perform computations on (default is CPU).
+    - bottleneck (bool): Flag to enable bottlenecking number of batches (default is False).
+    - compression_level (int): Blosc compressor compression level for zarr files (default is 1).
+    - skip_burdens (bool): Flag to skip burden computation (default is False).
+
+    Notes:
+    - All checkpoint models corresponding to the same repeat are averaged for that repeat.
+
+    Returns:
+    Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array]: Tuple containing genes, burdens, target y phenotypes, and x phenotypes.
+    """
     if not skip_burdens:
         logger.info("agg_models[*][*].reverse:")
         pprint(
@@ -272,6 +344,17 @@ def load_one_model(
     checkpoint: str,
     device: torch.device = torch.device("cpu"),
 ):
+    """
+    Load a single burden score computation model from a checkpoint file.
+
+    Parameters:
+    - config (Dict): Configuration dictionary.
+    - checkpoint (str): Path to the model checkpoint file.
+    - device (torch.device): Device to load the model onto (default is CPU).
+
+    Returns:
+    nn.Module: Loaded PyTorch model for burden score computation.
+    """
     model_class = getattr(deeprvat_models, config["model"]["type"])
     model = model_class.load_from_checkpoint(
         checkpoint,
@@ -290,6 +373,17 @@ def load_one_model(
 def reverse_models(
     model_config_file: str, data_config_file: str, checkpoint_files: Tuple[str]
 ):
+    """
+    Determine if the burden score computation PyTorch model should reverse the output based on PLOF annotations.
+
+    Parameters:
+    - model_config_file (str): Path to the model configuration file.
+    - data_config_file (str): Path to the data configuration file.
+    - checkpoint_files (Tuple[str]): Paths to checkpoint files.
+
+    Returns:
+    A checkpoint.reverse file is created if the model should reverse the burden score output.
+    """
     with open(model_config_file) as f:
         model_config = yaml.safe_load(f)

@@ -351,6 +445,23 @@ def load_models(
     checkpoint_files: Tuple[str],
     device: torch.device = torch.device("cpu"),
 ) -> Dict[str, List[nn.Module]]:
+    """
+    Load models from multiple checkpoints for multiple repeats.
+
+    Parameters:
+    - config (Dict): Configuration dictionary.
+    - checkpoint_files (Tuple[str]): Paths to checkpoint files.
+    - device (torch.device): Device to load the models onto (default is CPU).
+
+    Returns:
+    Dict[str, List[nn.Module]]: Dictionary of loaded PyTorch models for burden score computation for each repeat.
+
+    Examples:
+    >>> config = {"model": {"type": "MyModel", "config": {"param": "value"}}}
+    >>> checkpoint_files = ("checkpoint1.pth", "checkpoint2.pth")
+    >>> load_models(config, checkpoint_files)
+    {'repeat_0': [MyModel(), MyModel()]}
+    """
     logger.info("Loading models and checkpoints")

     if all(
@@ -430,6 +541,28 @@ def compute_burdens(
     checkpoint_files: Tuple[str],
     out_dir: str,
 ):
+    """
+    Compute burdens based on the provided model and dataset.
+
+    Parameters:
+    - debug (bool): Flag for debugging.
+    - bottleneck (bool): Flag to enable bottlenecking number of batches.
+    - n_chunks (Optional[int]): Number of chunks to split data for processing (default is None).
+    - chunk (Optional[int]): Index of the chunk of data (default is None).
+    - dataset_file (Optional[str]): Path to the dataset file, i.e. association_dataset.pkl.
+    - link_burdens (Optional[str]): Path to burden.zarr file to link.
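+        If provided, the burdens.zarr written to out_dir is symlinked to this existing file rather than being recomputed (see the symlink step at the end of this function).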
+    - data_config_file (str): Path to the data configuration file.
+    - model_config_file (str): Path to the model configuration file.
+    - checkpoint_files (Tuple[str]): Paths to model checkpoint files.
+    - out_dir (str): Path to the output directory.
+
+    Returns:
+    Computed burdens, corresponding genes, and targets are saved in the out_dir.
+    np.ndarray: Corresponding genes, saved as genes.npy
+    zarr.core.Array: Computed burdens, saved as burdens.zarr
+    zarr.core.Array: Target y phenotype, saved as y.zarr
+    zarr.core.Array: X phenotype, saved as x.zarr
+    """
     if len(checkpoint_files) == 0:
         raise ValueError("At least one checkpoint file must be supplied")

@@ -479,7 +612,19 @@ def compute_burdens(
         source_path.symlink_to(link_burdens)


-def regress_on_gene_scoretest(gene: str, burdens: np.ndarray, model_score):
+def regress_on_gene_scoretest(gene: str, burdens: np.ndarray, model_score,
+) -> Tuple[List[str], List[float], List[float]]:
+    """
+    Perform regression on a gene using the score test.
+
+    Parameters:
+    - gene (str): Gene name.
+    - burdens (np.ndarray): Burden scores associated with the gene.
+    - model_score: Model for score test.
+
+    Returns:
+    Tuple[List[str], List[float], List[float]]: Tuple containing gene name, beta, and p-value.
+    """
     burdens = burdens.reshape(burdens.shape[0], -1)
     logger.info(f"Burdens shape: {burdens.shape}")

@@ -517,7 +662,21 @@ def regress_on_gene(
     x_pheno: np.ndarray,
     use_bias: bool,
     use_x_pheno: bool,
-) -> Optional[Tuple[List[str], List[float], List[float]]]:
+) -> Tuple[List[str], List[float], List[float]]:
+    """
+    Perform regression on a gene using OLS.
+
+    Parameters:
+    - gene (str): Gene name.
+    - X (np.ndarray): Burden score data.
+    - y (np.ndarray): Y phenotype data.
+    - x_pheno (np.ndarray): X phenotype data.
+    - use_bias (bool): Flag to include bias term.
+    - use_x_pheno (bool): Flag to include x phenotype data in regression.
+
+    Returns:
+    Tuple[List[str], List[float], List[float]]: Tuple containing gene name, beta, and p-value.
+    """
     X = X.reshape(X.shape[0], -1)
     if np.all(np.abs(X) < 1e-6):
         logger.warning(f"Burden for gene {gene} is 0 for all samples; skipping")
@@ -561,6 +720,23 @@ def regress_(
     use_x_pheno: bool = True,
     do_scoretest: bool = True,
 ) -> pd.DataFrame:
+    """
+    Perform regression on multiple genes.
+
+    Parameters:
+    - config (Dict): Configuration dictionary.
+    - use_bias (bool): Flag to include bias term when performing OLS regression.
+    - burdens (np.ndarray): Burden score data.
+    - y (np.ndarray): Y phenotype data.
+    - gene_indices (np.ndarray): Indices of genes.
+    - genes (pd.Series): Gene names.
+    - x_pheno (np.ndarray): X phenotype data.
+    - use_x_pheno (bool): Flag to include x phenotype data when performing OLS regression (default is True).
+    - do_scoretest (bool): Flag to use the scoretest from SEAK (default is True).
+
+    Returns:
+    pd.DataFrame: DataFrame containing regression results on all genes.
+    """
     assert len(gene_indices) == len(genes)

     logger.info(f"Computing associations")
@@ -640,6 +816,25 @@ def regress(
     do_scoretest: bool,
     sample_file: Optional[str],
 ):
+    """
+    Perform regression analysis.
+
+    Parameters:
+    - debug (bool): Flag for debugging.
+    - chunk (int): Index of the chunk of data (default is 0).
+    - n_chunks (int): Number of chunks to split data for processing (default is 1).
+    - use_bias (bool): Flag to include bias term when performing OLS regression.
+    - gene_file (str): Path to the gene file.
+    - repeat (int): Index of the repeat (default is 0).
+    - config_file (str): Path to the configuration file.
+    - burden_dir (str): Path to the directory containing burdens.zarr file.
+    - out_dir (str): Path to the output directory.
+    - do_scoretest (bool): Flag to use the scoretest from SEAK.
+    - sample_file (Optional[str]): Path to the sample file.
+
+    Returns:
+    Regression results saved to out_dir as "burden_associations_{chunk}.parquet"
+    """
     logger.info("Loading saved burdens")
     y = zarr.open(Path(burden_dir) / "y.zarr")[:]
     burdens = zarr.open(Path(burden_dir) / "burdens.zarr")[:, :, repeat]
@@ -708,6 +903,17 @@ def regress(
 def combine_regression_results(
     result_files: Tuple[str], out_file: str, model_name: Optional[str]
 ):
+    """
+    Combine multiple regression result files.
+
+    Parameters:
+    - result_files (Tuple[str]): Paths to regression result files.
+    - out_file (str): Path to the output parquet file.
+    - model_name (Optional[str]): Name of the regression model.
+
+    Returns:
+    Concatenated regression results saved to a parquet file.
+    """
     logger.info(f"Concatenating results")
     results = pd.concat([pd.read_parquet(f, engine="pyarrow") for f in result_files])

diff --git a/deeprvat/deeprvat/config.py b/deeprvat/deeprvat/config.py
index 410f4141..a0b5bb5f 100644
--- a/deeprvat/deeprvat/config.py
+++ b/deeprvat/deeprvat/config.py
@@ -41,6 +41,26 @@ def update_config(
     seed_genes_out: Optional[str],
     new_config_file: str,
 ):
+    """
+    Select seed genes based on baseline results and update the configuration file.
+
+    Parameters:
+    - old_config_file (str): Path to the old configuration file.
+    - phenotype (Optional[str]): Phenotype to update in the configuration.
+    - seed_gene_dir (Optional[str]): Directory containing seed genes.
+    - baseline_results (Tuple[str]): Paths to baseline result files.
+    - baseline_results_out (Optional[str]): Path to save the updated baseline results.
+    - seed_genes_out (Optional[str]): Path to save the seed genes.
+    - new_config_file (str): Path to the new configuration file.
+
+    Raises:
+    - ValueError: If neither --seed-gene-dir nor --baseline-results is specified.
+
+    Returns:
+    Updated configuration file saved to new_config_file.
+    Selected seed genes saved as a parquet file to seed_genes_out.
+    Optionally, baseline results are saved to a parquet file if baseline_results_out is specified.
+    """
     if seed_gene_dir is None and len(baseline_results) == 0:
         raise ValueError(
             "One of --seed-gene-dir and --baseline-results " "must be specified"
         )
diff --git a/deeprvat/metrics.py b/deeprvat/metrics.py
index 429ddfb3..5bb3bdc5 100644
--- a/deeprvat/metrics.py
+++ b/deeprvat/metrics.py
@@ -15,10 +15,23 @@


 class RSquared:
+    """
+    Calculates the R-squared (coefficient of determination) between predictions and targets.
+    """
     def __init__(self):
         pass

     def __call__(self, preds: torch.tensor, targets: torch.tensor):
+        """
+        Calculate R-squared value between two tensors.
+
+        Parameters:
+        - preds (torch.tensor): Tensor containing predicted values.
+        - targets (torch.tensor): Tensor containing target values.
+
+        Returns:
+        torch.tensor: R-squared value.
+        """
         y_mean = torch.mean(targets)
         ss_tot = torch.sum(torch.square(targets - y_mean))
         ss_res = torch.sum(torch.square(targets - preds))
@@ -26,10 +39,23 @@ def __call__(self, preds: torch.tensor, targets: torch.tensor):


 class PearsonCorr:
+    """
+    Calculates the Pearson correlation coefficient between burdens and targets.
+    """
     def __init__(self):
         pass

     def __call__(self, burden, y):
+        """
+        Calculate Pearson correlation coefficient.
+
+        Parameters:
+        - burden (torch.tensor): Tensor containing burden values.
+        - y (torch.tensor): Tensor containing target values.
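+          Note: if burden has a second (gene) dimension, one coefficient is computed per gene column (see the loop over burden.shape[1] below).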
+
+        Returns:
+        float: Pearson correlation coefficient.
+        """
         if len(burden.shape) > 1:  # was the burden computed for >1 genes
             corrs = []
             for i in range(burden.shape[1]):  # number of genes
@@ -48,10 +74,23 @@ def __call__(self, burden, y):


 class PearsonCorrTorch:
+    """
+    Calculates the Pearson correlation coefficient between burdens and targets using PyTorch tensor operations.
+    """
     def __init__(self):
         pass

     def __call__(self, burden, y):
+        """
+        Calculate Pearson correlation coefficient using PyTorch tensor operations.
+
+        Parameters:
+        - burden (torch.tensor): Tensor containing burden values.
+        - y (torch.tensor): Tensor containing target values.
+
+        Returns:
+        torch.tensor: Pearson correlation coefficient.
+        """
         if len(burden.shape) > 1:  # was the burden computed for >1 genes
             corrs = []
             for i in range(burden.shape[1]):  # number of genes
@@ -83,9 +122,22 @@ def calculate_pearsonr(self, x, y):


 class AveragePrecisionWithLogits:
+    """
+    Calculates the average precision score between logits and targets.
+    """
     def __init__(self):
         pass

     def __call__(self, logits, y):
+        """
+        Calculate average precision score.
+
+        Parameters:
+        - logits (torch.tensor): Tensor containing logits.
+        - y (torch.tensor): Tensor containing target values.
+
+        Returns:
+        float: Average precision score.
+        """
         y_scores = F.sigmoid(logits.detach())
         return average_precision_score(y.detach().cpu().numpy(), y_scores.cpu().numpy())
diff --git a/deeprvat/utils.py b/deeprvat/utils.py
index ef9bdf99..3b896451 100644
--- a/deeprvat/utils.py
+++ b/deeprvat/utils.py
@@ -24,6 +24,16 @@


 def fdrcorrect_df(group: pd.DataFrame, alpha: float) -> pd.DataFrame:
+    """
+    Apply FDR correction to p-values in a DataFrame.
+
+    Parameters:
+    - group (pd.DataFrame): DataFrame containing a "pval" column.
+    - alpha (float): Significance level.
+
+    Returns:
+    pd.DataFrame: Original DataFrame with additional columns "significant" and "pval_corrected".
+    """
     group = group.copy()
     rejected, pval_corrected = fdrcorrection(group["pval"], alpha=alpha)

@@ -33,6 +43,16 @@ def fdrcorrect_df(group: pd.DataFrame, alpha: float) -> pd.DataFrame:


 def bfcorrect_df(group: pd.DataFrame, alpha: float) -> pd.DataFrame:
+    """
+    Apply Bonferroni correction to p-values in a DataFrame.
+
+    Parameters:
+    - group (pd.DataFrame): DataFrame containing a "pval" column.
+    - alpha (float): Significance level.
+
+    Returns:
+    pd.DataFrame: Original DataFrame with additional columns "significant" and "pval_corrected".
+    """
     group = group.copy()
     pval_corrected = group["pval"] * len(group)

@@ -42,6 +62,17 @@ def bfcorrect_df(group: pd.DataFrame, alpha: float) -> pd.DataFrame:


 def pval_correction(group: pd.DataFrame, alpha: float, correction_type: str = "FDR"):
+    """
+    Apply p-value correction to a DataFrame.
+
+    Parameters:
+    - group (pd.DataFrame): DataFrame containing a column named "pval" with p-values to correct.
+    - alpha (float): Significance level.
+    - correction_type (str): Type of p-value correction. Options are 'FDR' (default) and 'Bonferroni'.
+
+    Returns:
+    pd.DataFrame: Original DataFrame with additional columns "significant" and "pval_corrected".
+    """
     if correction_type == "FDR":
         corrected = fdrcorrect_df(group, alpha)
     elif correction_type == "Bonferroni":
@@ -57,6 +88,17 @@ def pval_correction(group: pd.DataFrame, alpha: float, correction_type: str = "F


 def suggest_hparams(config: Dict, trial: optuna.trial.Trial, basename: str = ""):
+    """
+    Suggest hyperparameters using Optuna's suggest methods.
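+    Entries whose value is a dictionary are treated as potential hyperparameter specifications; the returned copy (the input config is deep-copied) has these replaced by values suggested for the given trial.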
+
+    Parameters:
+    - config (Dict): Configuration dictionary with hyperparameter specifications.
+    - trial (optuna.trial.Trial): Optuna trial instance.
+    - basename (str): Base name for hyperparameter suggestions.
+
+    Returns:
+    Dict: Updated configuration with suggested hyperparameters.
+    """
     config = copy.deepcopy(config)
     for k, cfg in config.items():
         if isinstance(cfg, dict):
@@ -75,6 +117,15 @@ def suggest_hparams(config: Dict, trial: optuna.trial.Trial, basename: str = "")


 def compute_se(errors: np.ndarray) -> float:
+    """
+    Compute standard error.
+
+    Parameters:
+    - errors (np.ndarray): Array of errors.
+
+    Returns:
+    float: Standard error.
+    """
     mean_error = np.mean(errors)
     n = errors.shape[0]
     error_variance = np.mean((errors - mean_error) ** 2) / (n - 1) * n
@@ -82,6 +133,15 @@ def compute_se(errors: np.ndarray) -> float:


 def standardize_series(x: pd.Series) -> pd.Series:
+    """
+    Standardize a pandas Series.
+
+    Parameters:
+    - x (pd.Series): Input Series.
+
+    Returns:
+    pd.Series: Standardized Series.
+    """
     x = x.astype(np.float32)
     mean = x.mean()
     variance = ((x - mean) ** 2).mean()
@@ -91,7 +151,17 @@ def standardize_series(x: pd.Series) -> pd.Series:

 def my_quantile_transform(x, seed=1):
     """
-    returns Gaussian quantile transformed values, "nan" are kept
+    Gaussian quantile transform for values in a pandas Series.
+
+    Parameters:
+    - x: Input pandas Series.
+    - seed (int): Random seed.
+
+    Notes:
+    - "nan" values are kept
+
+    Returns:
+    pd.Series: Transformed Series.
     """
     np.random.seed(seed)
     x_transform = x.copy().to_numpy()
@@ -110,11 +180,32 @@ def my_quantile_transform(x, seed=1):


 def standardize_series_with_params(x: pd.Series, std, mean) -> pd.Series:
+    """
+    Standardize a pandas Series using provided standard deviation and mean.
+
+    Parameters:
+    - x (pd.Series): Input Series.
+    - std: Standard deviation to use for standardization.
+    - mean: Mean to use for standardization.
+
+    Returns:
+    pd.Series: Standardized Series.
+    """
     x = x.apply(lambda x: (x - mean) / std if x != 0 else 0)
     return x


 def calculate_mean_std(x: pd.Series, ignore_zero=True) -> pd.Series:
+    """
+    Calculate mean and standard deviation of a pandas Series.
+
+    Parameters:
+    - x (pd.Series): Input Series.
+    - ignore_zero (bool): Whether to ignore zero values in calculations (default is True).
+
+    Returns:
+    Tuple[float, float]: Tuple of standard deviation and mean.
+    """
     x = x.astype(np.float32)
     if ignore_zero:
         x = x[x != float(0)]
@@ -130,6 +221,22 @@ def safe_merge(
     validate: str = "1:1",
     equal_row_nums: bool = False,
 ):
+    """
+    Safely merge two pandas DataFrames.
+
+    Parameters:
+    - left (pd.DataFrame): Left DataFrame.
+    - right (pd.DataFrame): Right DataFrame.
+    - validate (str): Validation method for the merge.
+    - equal_row_nums (bool): Whether to check if the row numbers are equal (default is False).
+
+    Raises:
+    - ValueError: If left and right dataframe rows are unequal when 'equal_row_nums' is True.
+    - RuntimeError: If merged DataFrame has unequal row numbers compared to the left DataFrame.
+
+    Returns:
+    pd.DataFrame: Merged DataFrame.
+    """
     if equal_row_nums:
         try:
             assert len(left) == len(right)
@@ -153,6 +260,15 @@ def safe_merge(


 def resolve_path_with_env(path: str) -> str:
+    """
+    Resolve a path with environment variables.
+
+    Parameters:
+    - path (str): Input path.
+
+    Returns:
+    str: Resolved path.
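+
+    Example (illustrative only; assumes the referenced environment variable is set):
+    resolve_path_with_env("$HOME/deeprvat") -> "/home/user/deeprvat"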
+ """ path_split = [] head = path while head not in ("", "/"): @@ -168,6 +284,16 @@ def resolve_path_with_env(path: str) -> str: def copy_with_env(path: str, destination: str) -> str: + """ + Copy a file or directory to a destination with environment variables. + + Parameters: + - path (str): Input path (file or directory). + - destination (str): Destination path. + + Returns: + str: Resulting destination path. + """ destination = resolve_path_with_env(destination) if os.path.isfile(path): @@ -191,6 +317,16 @@ def copy_with_env(path: str, destination: str) -> str: def load_or_init(pickle_file: str, init_fn: Callable) -> Any: + """ + Load a pickled file or initialize an object. + + Parameters: + - pickle_file (str): Pickle file path. + - init_fn (Callable): Initialization function. + + Returns: + Any: Loaded or initialized object. + """ if pickle_file is not None and os.path.isfile(pickle_file): logger.info(f"Using pickled file {pickle_file}") with open(pickle_file, "rb") as f: @@ -204,6 +340,16 @@ def load_or_init(pickle_file: str, init_fn: Callable) -> Any: def remove_prefix(string, prefix): + """ + Remove a prefix from a string. + + Parameters: + - string (str): Input string. + - prefix (str): Prefix to remove. + + Returns: + str: String without the specified prefix. + """ if string.startswith(prefix): return string[len(prefix) :] return string @@ -218,6 +364,17 @@ def suggest_batch_size( }, buffer_bytes: int = 2_500_000_000, ): + """ + Suggest a batch size for a tensor based on available GPU memory. + + Parameters: + - tensor_shape (Iterable[int]): Shape of the tensor. + - example (Dict[str, Any]): Example dictionary with batch size, tensor shape, and max memory bytes. + - buffer_bytes (int): Buffer bytes to consider. + + Returns: + int: Suggested batch size for the given tensor shape and GPU memory. + """ gpu_mem_bytes = torch.cuda.get_device_properties(0).total_memory batch_size = math.floor( example["batch_size"]