diff --git a/deeprvat/deeprvat/associate.py b/deeprvat/deeprvat/associate.py index 0cfa8407..bec7b0f0 100644 --- a/deeprvat/deeprvat/associate.py +++ b/deeprvat/deeprvat/associate.py @@ -54,18 +54,20 @@ def get_burden( """ Compute burden scores for rare variants. - Parameters: - - batch (Dict): A dictionary containing batched data from the DataLoader. - - agg_models (Dict[str, List[nn.Module]]): Loaded PyTorch model(s) for each repeat used for burden computation. - Each key in the dictionary corresponds to a respective repeat. - - device (torch.device): Device to perform computations on (default is CPU). - - skip_burdens (bool): Flag to skip burden computation (default is False). - - Notes: - - Checkpoint models all corresponding to the same repeat are averaged for that repeat. - - Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing burden scores, target y phenotype values, and x phenotypes. + :param batch: A dictionary containing batched data from the DataLoader. + :type batch: Dict + :param agg_models: Loaded PyTorch model(s) for each repeat used for burden computation. + Each key in the dictionary corresponds to a respective repeat. + :type agg_models: Dict[str, List[nn.Module]] + :param device: Device to perform computations on, defaults to "cpu". + :type device: torch.device + :param skip_burdens: Flag to skip burden computation, defaults to False. + :type skip_burdens: bool + :return: Tuple containing burden scores, target y phenotype values, and x phenotypes. + :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + + .. note:: + Checkpoint models all corresponding to the same repeat are averaged for that repeat. """ with torch.no_grad(): X = batch["rare_variant_annotations"].to(device) @@ -93,11 +95,10 @@ def separate_parallel_results(results: List) -> Tuple[List, ...]: """ Separate results from running regression on each gene. - Parameters: - - results (List): List of results obtained from regression analysis. - - Returns: - Tuple[List, ...]: Tuple of lists containing separated results of regressed_genes, betas, and pvals. + :param results: List of results obtained from regression analysis. + :type results: List + :return: Tuple of lists containing separated results of regressed_genes, betas, and pvals. + :rtype: Tuple[List, ...] """ return tuple(map(list, zip(*results))) @@ -116,14 +117,16 @@ def make_dataset_( """ Create a dataset based on the configuration. - Parameters: - - config (Dict): Configuration dictionary. - - debug (bool): Flag for debugging (default is False). - - data_key (str): Key for dataset configuration in the config dictionary (default is "data"). - - samples (List[int]): List of sample indices to include in the dataset (default is None). - - Returns: - Dataset: Loaded instance of the created dataset. + :param config: Configuration dictionary. + :type config: Dict + :param debug: Flag for debugging, defaults to False. + :type debug: bool + :param data_key: Key for dataset configuration in the config dictionary, defaults to "data". + :type data_key: str + :param samples: List of sample indices to include in the dataset, defaults to None. + :type samples: List[int] + :return: Loaded instance of the created dataset. + :rtype: Dataset """ data_config = config[data_key] @@ -162,14 +165,15 @@ def make_dataset(debug: bool, data_key: str, config_file: str, out_file: str): """ Create a dataset based on the provided configuration and save to a pickle file. - Parameters: - - debug (bool): Flag for debugging. - - data_key (str): Key for dataset configuration in the config dictionary (default is "data"). - - config_file (str): Path to the configuration file. - - out_file (str): Path to the output file. - - Returns: - Created dataset saved to output.pkl + :param debug: Flag for debugging. + :type debug: bool + :param data_key: Key for dataset configuration in the config dictionary, defaults to "data". + :type data_key: str + :param config_file: Path to the configuration file. + :type config_file: str + :param out_file: Path to the output file. + :type out_file: str + :return: Created dataset saved to out_file.pkl """ with open(config_file) as f: config = yaml.safe_load(f) @@ -194,27 +198,36 @@ def compute_burdens_( skip_burdens: bool = False, ) -> Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array]: """ - Compute burdens using the PyTorch model for each repeat. - - Parameters: - - debug (bool): Flag for debugging. - - config (Dict): Configuration dictionary. - - ds (torch.utils.data.Dataset): Torch dataset. - - cache_dir (str): Directory to cache zarr files of computed burdens, x phenotypes, and y phenotypes. - - agg_models (Dict[str, List[nn.Module]]): Loaded PyTorch model(s) for each repeat used for burden computation. + Compute burdens using the PyTorch model for each repeat. + + :param debug: Flag for debugging. + :type debug: bool + :param config: Configuration dictionary. + :type config: Dict + :param ds: Torch dataset. + :type ds: torch.utils.data.Dataset + :param cache_dir: Directory to cache zarr files of computed burdens, x phenotypes, and y phenotypes. + :type cache_dir: str + :param agg_models: Loaded PyTorch model(s) for each repeat used for burden computation. Each key in the dictionary corresponds to a respective repeat. - - n_chunks (Optional[int]): Number of chunks to split data for processing (default is None). - - chunk (Optional[int]): Index of the chunk of data (default is None). - - device (torch.device): Device to perform computations on (default is CPU). - - bottleneck (bool): Flag to enable bottlenecking number of batches (default is False). - - compression_level (int): Blosc compressor compression level for zarr files (default is 1). - - skip_burdens (bool): Flag to skip burden computation (default is False). - - Notes: - - Checkpoint models all corresponding to the same repeat are averaged for that repeat. - - Returns: - Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array]: Tuple containing genes, burdens, target y phenotypes, and x phenotypes. + :type agg_models: Dict[str, List[nn.Module]] + :param n_chunks: Number of chunks to split data for processing, defaults to None. + :type n_chunks: Optional[int] + :param chunk: Index of the chunk of data, defaults to None. + :type chunk: Optional[int] + :param device: Device to perform computations on, defaults to "cpu". + :type device: torch.device + :param bottleneck: Flag to enable bottlenecking number of batches, defaults to False. + :type bottleneck: bool + :param compression_level: Blosc compressor compression level for zarr files, defaults to 1. + :type compression_level: int + :param skip_burdens: Flag to skip burden computation, defaults to False. + :type skip_burdens: bool + :return: Tuple containing genes, burdens, target y phenotypes, and x phenotypes. + :rtype: Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array] + + .. note:: + Checkpoint models all corresponding to the same repeat are averaged for that repeat. """ if not skip_burdens: logger.info("agg_models[*][*].reverse:") @@ -347,13 +360,14 @@ def load_one_model( """ Load a single burden score computation model from a checkpoint file. - Parameters: - - config (Dict): Configuration dictionary. - - checkpoint (str): Path to the model checkpoint file. - - device (torch.device): Device to load the model onto (default is CPU). - - Returns: - nn.Module: Loaded PyTorch model for burden score computation. + :param config: Configuration dictionary. + :type config: Dict + :param checkpoint: Path to the model checkpoint file. + :type checkpoint: str + :param device: Device to load the model onto, defaults to "cpu". + :type device: torch.device + :return: Loaded PyTorch model for burden score computation. + :rtype: nn.Module """ model_class = getattr(deeprvat_models, config["model"]["type"]) model = model_class.load_from_checkpoint( @@ -376,13 +390,13 @@ def reverse_models( """ Determine if the burden score computation PyTorch model should reverse the output based on PLOF annotations. - Parameters: - - model_config_file (str): Path to the model configuration file. - - data_config_file (str): Path to the data configuration file. - - checkpoint_files (Tuple[str]): Paths to checkpoint files. - - Returns: - checkpoint.reverse file is created if the model should reverse the burden score output. + :param model_config_file: Path to the model configuration file. + :type model_config_file: str + :param data_config_file: Path to the data configuration file. + :type data_config_file: str + :param checkpoint_files: Paths to checkpoint files. + :type checkpoint_files: Tuple[str] + :return: checkpoint.reverse file is created if the model should reverse the burden score output. """ with open(model_config_file) as f: model_config = yaml.safe_load(f) @@ -446,17 +460,19 @@ def load_models( device: torch.device = torch.device("cpu"), ) -> Dict[str, List[nn.Module]]: """ - Load models from multiple checkpoints for multiple repeats. + Load models from multiple checkpoints for multiple repeats. - Parameters: - - config (Dict): Configuration dictionary. - - checkpoint_files (Tuple[str]): Paths to checkpoint files. - - device (torch.device): Device to load the models onto (default is CPU). + :param config: Configuration dictionary. + :type config: Dict + :param checkpoint_files: Paths to checkpoint files. + :type checkpoint_files: Tuple[str] + :param device: Device to load the models onto, defaults to "cpu". + :type device: torch.device + :return: Dictionary of loaded PyTorch models for burden score computation for each repeat. + :rtype: Dict[str, List[nn.Module]] - Returns: - Dict[str, List[nn.Module]]: Dictionary of loaded PyTorch models for burden score computation for each repeat. + :Examples: - Examples: >>> config = {"model": {"type": "MyModel", "config": {"param": "value"}}} >>> checkpoint_files = ("checkpoint1.pth", "checkpoint2.pth") >>> load_models(config, checkpoint_files) @@ -542,26 +558,33 @@ def compute_burdens( out_dir: str, ): """ - Compute burdens based on model and dataset provided. - - Parameters: - - debug (bool): Flag for debugging. - - bottleneck (bool): Flag to enable bottlenecking number of batches. - - n_chunks (Optional[int]): Number of chunks to split data for processing (default is None). - - chunk (Optional[int]): Index of the chunk of data (default is None). - - dataset_file (Optional[str]): Path to the dataset file, i.e. association_dataset.pkl. - - link_burdens (Optional[str]): Path to burden.zarr file to link. - - data_config_file (str): Path to the data configuration file. - - model_config_file (str): Path to the model configuration file. - - checkpoint_files (Tuple[str]): Paths to model checkpoint files. - - out_dir (str): Path to the output directory. - - Returns: - Computed burdens, corresponding genes, and targets are saved in the out_dir. - np.ndarray: Corresponding genes, saved as genes.npy - zarr.core.Array: Computed burdens, saved as burdens.zarr - zarr.core.Array: Target y phenotype, saved as y.zarr - zarr.core.Array: X phenotype, saved as x.zarr + Compute burdens based on the provided model and dataset. + + :param debug: Flag for debugging. + :type debug: bool + :param bottleneck: Flag to enable bottlenecking number of batches. + :type bottleneck: bool + :param n_chunks: Number of chunks to split data for processing, defaults to None. + :type n_chunks: Optional[int] + :param chunk: Index of the chunk of data, defaults to None. + :type chunk: Optional[int] + :param dataset_file: Path to the dataset file, i.e., association_dataset.pkl. + :type dataset_file: Optional[str] + :param link_burdens: Path to burden.zarr file to link. + :type link_burdens: Optional[str] + :param data_config_file: Path to the data configuration file. + :type data_config_file: str + :param model_config_file: Path to the model configuration file. + :type model_config_file: str + :param checkpoint_files: Paths to model checkpoint files. + :type checkpoint_files: Tuple[str] + :param out_dir: Path to the output directory. + :type out_dir: str + :return: Corresonding genes, computed burdens, y phenotypes, and x phenotypes are saved in the out_dir. + :rtype: [np.ndarray], [zarr.core.Array], [zarr.core.Array], [zarr.core.Array] + + .. note:: + Checkpoint models all corresponding to the same repeat are averaged for that repeat. """ if len(checkpoint_files) == 0: raise ValueError("At least one checkpoint file must be supplied") @@ -617,13 +640,14 @@ def regress_on_gene_scoretest(gene: str, burdens: np.ndarray, model_score, """ Perform regression on a gene using the score test. - Parameters: - - gene (str): Gene name. - - burdens (np.ndarray): Burden scores associated with the gene. - - model_score: Model for score test. - - Returns: - Tuple[List[str], List[float], List[float]]: Tuple containing gene name, beta, and p-value. + :param gene: Gene name. + :type gene: str + :param burdens: Burden scores associated with the gene. + :type burdens: np.ndarray + :param model_score: Model for score test. + :type model_score: Any + :return: Tuple containing gene name, beta, and p-value. + :rtype: Tuple[List[str], List[float], List[float]] """ burdens = burdens.reshape(burdens.shape[0], -1) logger.info(f"Burdens shape: {burdens.shape}") @@ -664,18 +688,22 @@ def regress_on_gene( use_x_pheno: bool, ) -> Tuple[List[str], List[float], List[float]]: """ - Perform regression on a gene using OLS. - - Parameters: - - gene (str): Gene name. - - X (np.ndarray): Burden score data. - - y (np.ndarray): Y phenotype data. - - x_pheno (np.ndarray): X phenotype data. - - use_bias (bool): Flag to include bias term. - - use_x_pheno (bool): Flag to include x phenotype data in regression. - - Returns: - Tuple[List[str], List[float], List[float]]: Tuple containing gene name, beta, and p-value. + Perform regression on a gene using Ordinary Least Squares (OLS). + + :param gene: Gene name. + :type gene: str + :param X: Burden score data. + :type X: np.ndarray + :param y: Y phenotype data. + :type y: np.ndarray + :param x_pheno: X phenotype data. + :type x_pheno: np.ndarray + :param use_bias: Flag to include bias term. + :type use_bias: bool + :param use_x_pheno: Flag to include x phenotype data in regression. + :type use_x_pheno: bool + :return: Tuple containing gene name, beta, and p-value. + :rtype: Tuple[List[str], List[float], List[float]] """ X = X.reshape(X.shape[0], -1) if np.all(np.abs(X) < 1e-6): @@ -723,19 +751,26 @@ def regress_( """ Perform regression on multiple genes. - Parameters: - - config (Dict): Configuration dictionary. - - use_bias (bool): Flag to include bias term when performing OLS regression. - - burdens (np.ndarray): Burden score data. - - y (np.ndarray): Y phenotype data. - - gene_indices (np.ndarray): Indices of genes. - - genes (pd.Series): Gene names. - - x_pheno (np.ndarray): X phenotype data. - - use_x_pheno (bool): Flag to include x phenotype data when performing OLS regression (default is True). - - do_scoretest (bool): Flag to use the scoretest from SEAK (default is True). - - Returns: - pd.DataFrame: DataFrame containing regression results on all genes. + :param config: Configuration dictionary. + :type config: Dict + :param use_bias: Flag to include bias term when performing OLS regression. + :type use_bias: bool + :param burdens: Burden score data. + :type burdens: np.ndarray + :param y: Y phenotype data. + :type y: np.ndarray + :param gene_indices: Indices of genes. + :type gene_indices: np.ndarray + :param genes: Gene names. + :type genes: pd.Series + :param x_pheno: X phenotype data. + :type x_pheno: np.ndarray + :param use_x_pheno: Flag to include x phenotype data when performing OLS regression, defaults to True. + :type use_x_pheno: bool + :param do_scoretest: Flag to use the scoretest from SEAK, defaults to True. + :type do_scoretest: bool + :return: DataFrame containing regression results on all genes. + :rtype: pd.DataFrame """ assert len(gene_indices) == len(genes) @@ -819,21 +854,29 @@ def regress( """ Perform regression analysis. - Parameters: - - debug (bool): Flag for debugging. - - chunk (int): Index of the chunk of data (default is 0). - - n_chunks (int): Number of chunks to split data for processing (default is 1). - - use_bias (bool): Flag to include bias term when performing OLS regression. - - gene_file (str): Path to the gene file. - - repeat (int): Index of the repeat (default is 0). - - config_file (str): Path to the configuration file. - - burden_dir (str): Path to the directory containing burdens.zarr file. - - out_dir (str): Path to the output directory. - - do_scoretest (bool): Flag to use the scoretest from SEAK. - - sample_file (Optional[str]): Path to the sample file. - - Returns: - Regression results saved to out_dir as "burden_associations_{chunk}.parquet" + :param debug: Flag for debugging. + :type debug: bool + :param chunk: Index of the chunk of data, defaults to 0. + :type chunk: int + :param n_chunks: Number of chunks to split data for processing, defaults to 1. + :type n_chunks: int + :param use_bias: Flag to include bias term when performing OLS regression. + :type use_bias: bool + :param gene_file: Path to the gene file. + :type gene_file: str + :param repeat: Index of the repeat, defaults to 0. + :type repeat: int + :param config_file: Path to the configuration file. + :type config_file: str + :param burden_dir: Path to the directory containing burdens.zarr file. + :type burden_dir: str + :param out_dir: Path to the output directory. + :type out_dir: str + :param do_scoretest: Flag to use the scoretest from SEAK. + :type do_scoretest: bool + :param sample_file: Path to the sample file. + :type sample_file: Optional[str] + :return: Regression results saved to out_dir as "burden_associations_{chunk}.parquet" """ logger.info("Loading saved burdens") y = zarr.open(Path(burden_dir) / "y.zarr")[:] @@ -906,13 +949,13 @@ def combine_regression_results( """ Combine multiple regression result files. - Parameters: - - model_name (str): Name of the regression model. - - result_files (List[str]): List of paths to regression result files. - - out_dir (str): Path to the output directory. - - Returns: - Concatenated regression results saved to a parquet file. + :param result_files: List of paths to regression result files. + :type result_files: Tuple[str] + :param out_file: Path to the output file. + :type out_file: str + :param model_name: Name of the regression model. + :type model_name: Optional[str] + :return: Concatenated regression results saved to a parquet file. """ logger.info(f"Concatenating results") results = pd.concat([pd.read_parquet(f, engine="pyarrow") for f in result_files]) diff --git a/deeprvat/deeprvat/config.py b/deeprvat/deeprvat/config.py index a0b5bb5f..1d4de29d 100644 --- a/deeprvat/deeprvat/config.py +++ b/deeprvat/deeprvat/config.py @@ -42,24 +42,26 @@ def update_config( new_config_file: str, ): """ - Select seed genes based on baseline results and update configuration file. - - Parameters: - - old_config_file (str): Path to the old configuration file. - - phenotype (Optional[str]): Phenotype to update in the configuration. - - seed_gene_dir (Optional[str]): Directory containing seed genes. - - baseline_results (Tuple[str]): Paths to baseline result files. - - baseline_results_out (Optional[str]): Path to save the updated baseline results. - - seed_genes_out (Optional[str]): Path to save the seed genes. - - new_config_file (str): Path to the new configuration file. - - Raises: - - ValueError: If neither --seed-gene-dir nor --baseline-results is specified. - - Returns: - Updated configuration file saved to new_config.yaml. - Selected seed genes saved to seed_genes_out.parquet. - Optionally, save baseline results to parquet file if baseline_results_out is specified. + Select seed genes based on baseline results and update the configuration file. + + :param old_config_file: Path to the old configuration file. + :type old_config_file: str + :param phenotype: Phenotype to update in the configuration. + :type phenotype: Optional[str] + :param seed_gene_dir: Directory containing seed genes. + :type seed_gene_dir: Optional[str] + :param baseline_results: Paths to baseline result files. + :type baseline_results: Tuple[str] + :param baseline_results_out: Path to save the updated baseline results. + :type baseline_results_out: Optional[str] + :param seed_genes_out: Path to save the seed genes. + :type seed_genes_out: Optional[str] + :param new_config_file: Path to the new configuration file. + :type new_config_file: str + :raises ValueError: If neither --seed-gene-dir nor --baseline-results is specified. + :return: Updated configuration file saved to new_config.yaml. + Selected seed genes saved to seed_genes_out.parquet. + Optionally, save baseline results to a parquet file if baseline_results_out is specified. """ if seed_gene_dir is None and len(baseline_results) == 0: raise ValueError( diff --git a/deeprvat/metrics.py b/deeprvat/metrics.py index 5bb3bdc5..c65c1679 100644 --- a/deeprvat/metrics.py +++ b/deeprvat/metrics.py @@ -25,12 +25,12 @@ def __call__(self, preds: torch.tensor, targets: torch.tensor): """ Calculate R-squared value between two tensors. - Parameters: - - preds (torch.tensor): Tensor containing predicted values. - - targets (torch.tensor): Tensor containing target values. - - Returns: - torch.tensor: R-squared value. + :param preds: Tensor containing predicted values. + :type preds: torch.tensor + :param targets: Tensor containing target values. + :type targets: torch.tensor + :return: R-squared value. + :rtype: torch.tensor """ y_mean = torch.mean(targets) ss_tot = torch.sum(torch.square(targets - y_mean)) @@ -49,12 +49,12 @@ def __call__(self, burden, y): """ Calculate Pearson correlation coefficient. - Parameters: - - burden (torch.tensor): Tensor containing burden values. - - y (torch.tensor): Tensor containing target values. - - Returns: - float: Pearson correlation coefficient. + :param burden: Tensor containing burden values. + :type burden: torch.tensor + :param y: Tensor containing target values. + :type y: torch.tensor + :return: Pearson correlation coefficient. + :rtype: float """ if len(burden.shape) > 1: # was the burden computed for >1 genes corrs = [] @@ -84,12 +84,12 @@ def __call__(self, burden, y): """ Calculate Pearson correlation coefficient using PyTorch tensor operations. - Parameters: - - burden (torch.tensor): Tensor containing burden values. - - y (torch.tensor): Tensor containing target values. - - Returns: - torch.tensor: Pearson correlation coefficient. + :param burden: Tensor containing burden values. + :type burden: torch.tensor + :param y: Tensor containing target values. + :type y: torch.tensor + :return: Pearson correlation coefficient. + :rtype: torch.tensor """ if len(burden.shape) > 1: # was the burden computed for >1 genes corrs = [] @@ -132,12 +132,12 @@ def __call__(self, logits, y): """ Calculate average precision score. - Parameters: - - logits (torch.tensor): Tensor containing logits. - - y (torch.tensor): Tensor containing target values. - - Returns: - float: Average precision score. + :param logits: Tensor containing logits. + :type logits: torch.tensor + :param y: Tensor containing target values. + :type y: torch.tensor + :return: Average precision score. + :rtype: float """ y_scores = F.sigmoid(logits.detach()) return average_precision_score(y.detach().cpu().numpy(), y_scores.cpu().numpy()) diff --git a/deeprvat/utils.py b/deeprvat/utils.py index 3b896451..af0411c9 100644 --- a/deeprvat/utils.py +++ b/deeprvat/utils.py @@ -25,14 +25,14 @@ def fdrcorrect_df(group: pd.DataFrame, alpha: float) -> pd.DataFrame: """ - Apply FDR correction to p-values in a DataFrame. - - Parameters: - - group (pd.DataFrame): DataFrame containing a "pval" column. - - alpha (float): Significance level. - - Returns: - pd.DataFrame: Original DataFrame with additional columns "significant" and "pval_corrected". + Apply False Discovery Rate (FDR) correction to p-values in a DataFrame. + + :param group: DataFrame containing a "pval" column. + :type group: pd.DataFrame + :param alpha: Significance level. + :type alpha: float + :return: Original DataFrame with additional columns "significant" and "pval_corrected". + :rtype: pd.DataFrame """ group = group.copy() @@ -46,12 +46,12 @@ def bfcorrect_df(group: pd.DataFrame, alpha: float) -> pd.DataFrame: """ Apply Bonferroni correction to p-values in a DataFrame. - Parameters: - - group (pd.DataFrame): DataFrame containing a "pval" column. - - alpha (float): Significance level. - - Returns: - pd.DataFrame: Original DataFrame with additional columns "significant" and "pval_corrected". + :param group: DataFrame containing a "pval" column. + :type group: pd.DataFrame + :param alpha: Significance level. + :type alpha: float + :return: Original DataFrame with additional columns "significant" and "pval_corrected". + :rtype: pd.DataFrame """ group = group.copy() @@ -65,13 +65,14 @@ def pval_correction(group: pd.DataFrame, alpha: float, correction_type: str = "F """ Apply p-value correction to a DataFrame. - Parameters: - - group (pd.DataFrame): DataFrame containing a column named "pval" with p-values to correct. - - alpha (float): Significance level. - - correction_type (str): Type of p-value correction. Options are 'FDR' (default) and 'Bonferroni'. - - Returns: - pd.DataFrame: Original DataFrame with additional columns "significant" and "pval_corrected". + :param group: DataFrame containing a column named "pval" with p-values to correct. + :type group: pd.DataFrame + :param alpha: Significance level. + :type alpha: float + :param correction_type: Type of p-value correction. Options are 'FDR' (default) and 'Bonferroni'. + :type correction_type: str + :return: Original DataFrame with additional columns "significant" and "pval_corrected". + :rtype: pd.DataFrame """ if correction_type == "FDR": corrected = fdrcorrect_df(group, alpha) @@ -87,17 +88,18 @@ def pval_correction(group: pd.DataFrame, alpha: float, correction_type: str = "F return corrected -def suggest_hparams(config: Dict, trial: optuna.trial.Trial, basename: str = ""): +def suggest_hparams(config: Dict, trial: optuna.trial.Trial, basename: str = "") -> Dict: """ Suggest hyperparameters using Optuna's suggest methods. - Parameters: - - config (Dict): Configuration dictionary with hyperparameter specifications. - - trial (optuna.trial.Trial): Optuna trial instance. - - basename (str): Base name for hyperparameter suggestions. - - Returns: - Dict: Updated configuration with suggested hyperparameters. + :param config: Configuration dictionary with hyperparameter specifications. + :type config: Dict + :param trial: Optuna trial instance. + :type trial: optuna.trial.Trial + :param basename: Base name for hyperparameter suggestions. + :type basename: str + :return: Updated configuration with suggested hyperparameters. + :rtype: Dict """ config = copy.deepcopy(config) for k, cfg in config.items(): @@ -120,11 +122,10 @@ def compute_se(errors: np.ndarray) -> float: """ Compute standard error. - Parameters: - - errors (np.ndarray): Array of errors. - - Returns: - float: Standard error. + :param errors: Array of errors. + :type errors: np.ndarray + :return: Standard error. + :rtype: float """ mean_error = np.mean(errors) n = errors.shape[0] @@ -136,11 +137,10 @@ def standardize_series(x: pd.Series) -> pd.Series: """ Standardize a pandas Series. - Parameters: - - x (pd.Series): Input Series. - - Returns: - pd.Series: Standardized Series. + :param x: Input Series. + :type x: pd.Series + :return: Standardized Series. + :rtype: pd.Series """ x = x.astype(np.float32) mean = x.mean() @@ -153,15 +153,15 @@ def my_quantile_transform(x, seed=1): """ Gaussian quantile transform for values in a pandas Series. - Parameters: - - x: Input pandas Series. - - seed (int): Random seed. + :param x: Input pandas Series. + :type x: pd.Series + :param seed: Random seed. + :type seed: int + :return: Transformed Series. + :rtype: pd.Series - Notes: - - "nan" values are kept - - Returns: - pd.Series: Transformed Series. + .. note:: + "nan" values are kept """ np.random.seed(seed) x_transform = x.copy().to_numpy() @@ -183,13 +183,12 @@ def standardize_series_with_params(x: pd.Series, std, mean) -> pd.Series: """ Standardize a pandas Series using provided standard deviation and mean. - Parameters: - - x (pd.Series): Input Series. - - std: Standard deviation to use for standardization. - - mean: Mean to use for standardization. - - Returns: - pd.Series: Standardized Series. + :param x: Input Series. + :type x: pd.Series + :param std: Standard deviation to use for standardization. + :param mean: Mean to use for standardization. + :return: Standardized Series. + :rtype: pd.Series """ x = x.apply(lambda x: (x - mean) / std if x != 0 else 0) return x @@ -199,12 +198,12 @@ def calculate_mean_std(x: pd.Series, ignore_zero=True) -> pd.Series: """ Calculate mean and standard deviation of a pandas Series. - Parameters: - - x (pd.Series): Input Series. - - ignore_zero (bool): Whether to ignore zero values in calculations (default is True). - - Returns: - Tuple[float, float]: Tuple of standard deviation and mean. + :param x: Input Series. + :type x: pd.Series + :param ignore_zero: Whether to ignore zero values in calculations, defaults to True. + :type ignore_zero: bool + :return: Tuple of standard deviation and mean. + :rtype: Tuple[float, float] """ x = x.astype(np.float32) if ignore_zero: @@ -224,18 +223,18 @@ def safe_merge( """ Safely merge two pandas DataFrames. - Parameters: - - left (pd.DataFrame): Left DataFrame. - - right (pd.DataFrame): Right DataFrame. - - validate (str): Validation method for the merge. - - equal_row_nums (bool): Whether to check if the row numbers are equal (default is False). - - Raises: - - ValueError: If left and right dataframe rows are unequal when 'equal_row_nums' is True. - - RuntimeError: If merged DataFrame has unequal row numbers compared to the left DataFrame. - - Returns: - pd.DataFrame: Merged DataFrame. + :param left: Left DataFrame. + :type left: pd.DataFrame + :param right: Right DataFrame. + :type right: pd.DataFrame + :param validate: Validation method for the merge. + :type validate: str + :param equal_row_nums: Whether to check if the row numbers are equal, defaults to False. + :type equal_row_nums: bool + :raises ValueError: If left and right dataframe rows are unequal when 'equal_row_nums' is True. + :raises RuntimeError: If merged DataFrame has unequal row numbers compared to the left DataFrame. + :return: Merged DataFrame. + :rtype: pd.DataFrame """ if equal_row_nums: try: @@ -263,11 +262,10 @@ def resolve_path_with_env(path: str) -> str: """ Resolve a path with environment variables. - Parameters: - - path (str): Input path. - - Returns: - str: Resolved path. + :param path: Input path. + :type path: str + :return: Resolved path. + :rtype: str """ path_split = [] head = path @@ -287,12 +285,12 @@ def copy_with_env(path: str, destination: str) -> str: """ Copy a file or directory to a destination with environment variables. - Parameters: - - path (str): Input path (file or directory). - - destination (str): Destination path. - - Returns: - str: Resulting destination path. + :param path: Input path (file or directory). + :type path: str + :param destination: Destination path. + :type destination: str + :return: Resulting destination path. + :rtype: str """ destination = resolve_path_with_env(destination) @@ -320,12 +318,12 @@ def load_or_init(pickle_file: str, init_fn: Callable) -> Any: """ Load a pickled file or initialize an object. - Parameters: - - pickle_file (str): Pickle file path. - - init_fn (Callable): Initialization function. - - Returns: - Any: Loaded or initialized object. + :param pickle_file: Pickle file path. + :type pickle_file: str + :param init_fn: Initialization function. + :type init_fn: Callable + :return: Loaded or initialized object. + :rtype: Any """ if pickle_file is not None and os.path.isfile(pickle_file): logger.info(f"Using pickled file {pickle_file}") @@ -343,12 +341,12 @@ def remove_prefix(string, prefix): """ Remove a prefix from a string. - Parameters: - - string (str): Input string. - - prefix (str): Prefix to remove. - - Returns: - str: String without the specified prefix. + :param string: Input string. + :type string: str + :param prefix: Prefix to remove. + :type prefix: str + :return: String without the specified prefix. + :rtype: str """ if string.startswith(prefix): return string[len(prefix) :] @@ -367,13 +365,14 @@ def suggest_batch_size( """ Suggest a batch size for a tensor based on available GPU memory. - Parameters: - - tensor_shape (Iterable[int]): Shape of the tensor. - - example (Dict[str, Any]): Example dictionary with batch size, tensor shape, and max memory bytes. - - buffer_bytes (int): Buffer bytes to consider. - - Returns: - int: Suggested batch size for the given tensor shape and GPU memory. + :param tensor_shape: Shape of the tensor. + :type tensor_shape: Iterable[int] + :param example: Example dictionary with batch size, tensor shape, and max memory bytes. + :type example: Dict[str, Any] + :param buffer_bytes: Buffer bytes to consider. + :type buffer_bytes: int + :return: Suggested batch size for the given tensor shape and GPU memory. + :rtype: int """ gpu_mem_bytes = torch.cuda.get_device_properties(0).total_memory batch_size = math.floor(