diff --git a/deeprvat/deeprvat/associate.py b/deeprvat/deeprvat/associate.py index a9d29228..c5c26a94 100644 --- a/deeprvat/deeprvat/associate.py +++ b/deeprvat/deeprvat/associate.py @@ -51,6 +51,24 @@ def get_burden( device: torch.device = torch.device("cpu"), skip_burdens=False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute burden scores for rare variants. + + :param batch: A dictionary containing batched data from the DataLoader. + :type batch: Dict + :param agg_models: Loaded PyTorch model(s) for each repeat used for burden computation. + Each key in the dictionary corresponds to a respective repeat. + :type agg_models: Dict[str, List[nn.Module]] + :param device: Device to perform computations on, defaults to "cpu". + :type device: torch.device + :param skip_burdens: Flag to skip burden computation, defaults to False. + :type skip_burdens: bool + :return: Tuple containing burden scores, target y phenotype values, and x phenotypes. + :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + + .. note:: + Checkpoint models all corresponding to the same repeat are averaged for that repeat. + """ with torch.no_grad(): X = batch["rare_variant_annotations"].to(device) burden = [] @@ -74,6 +92,14 @@ def get_burden( def separate_parallel_results(results: List) -> Tuple[List, ...]: + """ + Separate results from running regression on each gene. + + :param results: List of results obtained from regression analysis. + :type results: List + :return: Tuple of lists containing separated results of regressed_genes, betas, and pvals. + :rtype: Tuple[List, ...] + """ return tuple(map(list, zip(*results))) @@ -88,6 +114,20 @@ def make_dataset_( data_key="data", samples: Optional[List[int]] = None, ) -> Dataset: + """ + Create a dataset based on the configuration. + + :param config: Configuration dictionary. + :type config: Dict + :param debug: Flag for debugging, defaults to False. + :type debug: bool + :param data_key: Key for dataset configuration in the config dictionary, defaults to "data". + :type data_key: str + :param samples: List of sample indices to include in the dataset, defaults to None. + :type samples: List[int] + :return: Loaded instance of the created dataset. + :rtype: Dataset + """ data_config = config[data_key] ds_pickled = data_config.get("pickled", None) @@ -122,6 +162,19 @@ def make_dataset_( @click.argument("config-file", type=click.Path(exists=True)) @click.argument("out-file", type=click.Path()) def make_dataset(debug: bool, data_key: str, config_file: str, out_file: str): + """ + Create a dataset based on the provided configuration and save to a pickle file. + + :param debug: Flag for debugging. + :type debug: bool + :param data_key: Key for dataset configuration in the config dictionary, defaults to "data". + :type data_key: str + :param config_file: Path to the configuration file. + :type config_file: str + :param out_file: Path to the output file. + :type out_file: str + :return: Created dataset saved to out_file.pkl + """ with open(config_file) as f: config = yaml.safe_load(f) @@ -144,6 +197,38 @@ def compute_burdens_( compression_level: int = 1, skip_burdens: bool = False, ) -> Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array]: + """ + Compute burdens using the PyTorch model for each repeat. + + :param debug: Flag for debugging. + :type debug: bool + :param config: Configuration dictionary. + :type config: Dict + :param ds: Torch dataset. 
+ :type ds: torch.utils.data.Dataset + :param cache_dir: Directory to cache zarr files of computed burdens, x phenotypes, and y phenotypes. + :type cache_dir: str + :param agg_models: Loaded PyTorch model(s) for each repeat used for burden computation. + Each key in the dictionary corresponds to a respective repeat. + :type agg_models: Dict[str, List[nn.Module]] + :param n_chunks: Number of chunks to split data for processing, defaults to None. + :type n_chunks: Optional[int] + :param chunk: Index of the chunk of data, defaults to None. + :type chunk: Optional[int] + :param device: Device to perform computations on, defaults to "cpu". + :type device: torch.device + :param bottleneck: Flag to enable bottlenecking number of batches, defaults to False. + :type bottleneck: bool + :param compression_level: Blosc compressor compression level for zarr files, defaults to 1. + :type compression_level: int + :param skip_burdens: Flag to skip burden computation, defaults to False. + :type skip_burdens: bool + :return: Tuple containing genes, burdens, target y phenotypes, and x phenotypes. + :rtype: Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array] + + .. note:: + Checkpoint models all corresponding to the same repeat are averaged for that repeat. + """ if not skip_burdens: logger.info("agg_models[*][*].reverse:") pprint( @@ -272,6 +357,18 @@ def load_one_model( checkpoint: str, device: torch.device = torch.device("cpu"), ): + """ + Load a single burden score computation model from a checkpoint file. + + :param config: Configuration dictionary. + :type config: Dict + :param checkpoint: Path to the model checkpoint file. + :type checkpoint: str + :param device: Device to load the model onto, defaults to "cpu". + :type device: torch.device + :return: Loaded PyTorch model for burden score computation. + :rtype: nn.Module + """ model_class = getattr(deeprvat_models, config["model"]["type"]) model = model_class.load_from_checkpoint( checkpoint, @@ -290,6 +387,17 @@ def load_one_model( def reverse_models( model_config_file: str, data_config_file: str, checkpoint_files: Tuple[str] ): + """ + Determine if the burden score computation PyTorch model should reverse the output based on PLOF annotations. + + :param model_config_file: Path to the model configuration file. + :type model_config_file: str + :param data_config_file: Path to the data configuration file. + :type data_config_file: str + :param checkpoint_files: Paths to checkpoint files. + :type checkpoint_files: Tuple[str] + :return: checkpoint.reverse file is created if the model should reverse the burden score output. + """ with open(model_config_file) as f: model_config = yaml.safe_load(f) @@ -351,6 +459,25 @@ def load_models( checkpoint_files: Tuple[str], device: torch.device = torch.device("cpu"), ) -> Dict[str, List[nn.Module]]: + """ + Load models from multiple checkpoints for multiple repeats. + + :param config: Configuration dictionary. + :type config: Dict + :param checkpoint_files: Paths to checkpoint files. + :type checkpoint_files: Tuple[str] + :param device: Device to load the models onto, defaults to "cpu". + :type device: torch.device + :return: Dictionary of loaded PyTorch models for burden score computation for each repeat. 
+ :rtype: Dict[str, List[nn.Module]] + + :Examples: + + >>> config = {"model": {"type": "MyModel", "config": {"param": "value"}}} + >>> checkpoint_files = ("checkpoint1.pth", "checkpoint2.pth") + >>> load_models(config, checkpoint_files) + {'repeat_0': [MyModel(), MyModel()]} + """ logger.info("Loading models and checkpoints") if all( @@ -430,6 +557,35 @@ def compute_burdens( checkpoint_files: Tuple[str], out_dir: str, ): + """ + Compute burdens based on the provided model and dataset. + + :param debug: Flag for debugging. + :type debug: bool + :param bottleneck: Flag to enable bottlenecking number of batches. + :type bottleneck: bool + :param n_chunks: Number of chunks to split data for processing, defaults to None. + :type n_chunks: Optional[int] + :param chunk: Index of the chunk of data, defaults to None. + :type chunk: Optional[int] + :param dataset_file: Path to the dataset file, i.e., association_dataset.pkl. + :type dataset_file: Optional[str] + :param link_burdens: Path to burden.zarr file to link. + :type link_burdens: Optional[str] + :param data_config_file: Path to the data configuration file. + :type data_config_file: str + :param model_config_file: Path to the model configuration file. + :type model_config_file: str + :param checkpoint_files: Paths to model checkpoint files. + :type checkpoint_files: Tuple[str] + :param out_dir: Path to the output directory. + :type out_dir: str + :return: Corresponding genes, computed burdens, y phenotypes, and x phenotypes are saved in the out_dir. + :rtype: [np.ndarray], [zarr.core.Array], [zarr.core.Array], [zarr.core.Array] + + .. note:: + Checkpoint models all corresponding to the same repeat are averaged for that repeat. + """ if len(checkpoint_files) == 0: raise ValueError("At least one checkpoint file must be supplied") @@ -479,7 +635,23 @@ def compute_burdens( source_path.symlink_to(link_burdens) -def regress_on_gene_scoretest(gene: str, burdens: np.ndarray, model_score): +def regress_on_gene_scoretest( + gene: str, + burdens: np.ndarray, + model_score, +) -> Tuple[List[str], List[float], List[float]]: + """ + Perform regression on a gene using the score test. + + :param gene: Gene name. + :type gene: str + :param burdens: Burden scores associated with the gene. + :type burdens: np.ndarray + :param model_score: Model for score test. + :type model_score: Any + :return: Tuple containing gene name, beta, and p-value. + :rtype: Tuple[List[str], List[float], List[float]] + """ burdens = burdens.reshape(burdens.shape[0], -1) logger.info(f"Burdens shape: {burdens.shape}") @@ -517,7 +689,25 @@ def regress_on_gene( x_pheno: np.ndarray, use_bias: bool, use_x_pheno: bool, -) -> Optional[Tuple[List[str], List[float], List[float]]]: +) -> Tuple[List[str], List[float], List[float]]: + """ + Perform regression on a gene using Ordinary Least Squares (OLS). + + :param gene: Gene name. + :type gene: str + :param X: Burden score data. + :type X: np.ndarray + :param y: Y phenotype data. + :type y: np.ndarray + :param x_pheno: X phenotype data. + :type x_pheno: np.ndarray + :param use_bias: Flag to include bias term. + :type use_bias: bool + :param use_x_pheno: Flag to include x phenotype data in regression. + :type use_x_pheno: bool + :return: Tuple containing gene name, beta, and p-value. 
+ :rtype: Tuple[List[str], List[float], List[float]] + """ X = X.reshape(X.shape[0], -1) if np.all(np.abs(X) < 1e-6): logger.warning(f"Burden for gene {gene} is 0 for all samples; skipping") @@ -561,6 +751,30 @@ def regress_( use_x_pheno: bool = True, do_scoretest: bool = True, ) -> pd.DataFrame: + """ + Perform regression on multiple genes. + + :param config: Configuration dictionary. + :type config: Dict + :param use_bias: Flag to include bias term when performing OLS regression. + :type use_bias: bool + :param burdens: Burden score data. + :type burdens: np.ndarray + :param y: Y phenotype data. + :type y: np.ndarray + :param gene_indices: Indices of genes. + :type gene_indices: np.ndarray + :param genes: Gene names. + :type genes: pd.Series + :param x_pheno: X phenotype data. + :type x_pheno: np.ndarray + :param use_x_pheno: Flag to include x phenotype data when performing OLS regression, defaults to True. + :type use_x_pheno: bool + :param do_scoretest: Flag to use the scoretest from SEAK, defaults to True. + :type do_scoretest: bool + :return: DataFrame containing regression results on all genes. + :rtype: pd.DataFrame + """ assert len(gene_indices) == len(genes) logger.info(f"Computing associations") @@ -640,6 +854,33 @@ def regress( do_scoretest: bool, sample_file: Optional[str], ): + """ + Perform regression analysis. + + :param debug: Flag for debugging. + :type debug: bool + :param chunk: Index of the chunk of data, defaults to 0. + :type chunk: int + :param n_chunks: Number of chunks to split data for processing, defaults to 1. + :type n_chunks: int + :param use_bias: Flag to include bias term when performing OLS regression. + :type use_bias: bool + :param gene_file: Path to the gene file. + :type gene_file: str + :param repeat: Index of the repeat, defaults to 0. + :type repeat: int + :param config_file: Path to the configuration file. + :type config_file: str + :param burden_dir: Path to the directory containing burdens.zarr file. + :type burden_dir: str + :param out_dir: Path to the output directory. + :type out_dir: str + :param do_scoretest: Flag to use the scoretest from SEAK. + :type do_scoretest: bool + :param sample_file: Path to the sample file. + :type sample_file: Optional[str] + :return: Regression results saved to out_dir as "burden_associations_{chunk}.parquet" + """ logger.info("Loading saved burdens") y = zarr.open(Path(burden_dir) / "y.zarr")[:] burdens = zarr.open(Path(burden_dir) / "burdens.zarr")[:, :, repeat] @@ -708,6 +949,17 @@ def regress( def combine_regression_results( result_files: Tuple[str], out_file: str, model_name: Optional[str] ): + """ + Combine multiple regression result files. + + :param result_files: List of paths to regression result files. + :type result_files: Tuple[str] + :param out_file: Path to the output file. + :type out_file: str + :param model_name: Name of the regression model. + :type model_name: Optional[str] + :return: Concatenated regression results saved to a parquet file. + """ logger.info(f"Concatenating results") results = pd.concat([pd.read_parquet(f, engine="pyarrow") for f in result_files]) diff --git a/deeprvat/deeprvat/config.py b/deeprvat/deeprvat/config.py index 410f4141..1d4de29d 100644 --- a/deeprvat/deeprvat/config.py +++ b/deeprvat/deeprvat/config.py @@ -41,6 +41,28 @@ def update_config( seed_genes_out: Optional[str], new_config_file: str, ): + """ + Select seed genes based on baseline results and update the configuration file. + + :param old_config_file: Path to the old configuration file. 
+ :type old_config_file: str + :param phenotype: Phenotype to update in the configuration. + :type phenotype: Optional[str] + :param seed_gene_dir: Directory containing seed genes. + :type seed_gene_dir: Optional[str] + :param baseline_results: Paths to baseline result files. + :type baseline_results: Tuple[str] + :param baseline_results_out: Path to save the updated baseline results. + :type baseline_results_out: Optional[str] + :param seed_genes_out: Path to save the seed genes. + :type seed_genes_out: Optional[str] + :param new_config_file: Path to the new configuration file. + :type new_config_file: str + :raises ValueError: If neither --seed-gene-dir nor --baseline-results is specified. + :return: Updated configuration file saved to new_config.yaml. + Selected seed genes saved to seed_genes_out.parquet. + Optionally, save baseline results to a parquet file if baseline_results_out is specified. + """ if seed_gene_dir is None and len(baseline_results) == 0: raise ValueError( "One of --seed-gene-dir and --baseline-results " "must be specified" diff --git a/deeprvat/metrics.py b/deeprvat/metrics.py index 429ddfb3..f7b74a01 100644 --- a/deeprvat/metrics.py +++ b/deeprvat/metrics.py @@ -15,10 +15,24 @@ class RSquared: + """ + Calculates the R-squared (coefficient of determination) between predictions and targets. + """ + def __init__(self): pass def __call__(self, preds: torch.tensor, targets: torch.tensor): + """ + Calculate R-squared value between two tensors. + + :param preds: Tensor containing predicted values. + :type preds: torch.tensor + :param targets: Tensor containing target values. + :type targets: torch.tensor + :return: R-squared value. + :rtype: torch.tensor + """ y_mean = torch.mean(targets) ss_tot = torch.sum(torch.square(targets - y_mean)) ss_res = torch.sum(torch.square(targets - preds)) @@ -26,10 +40,24 @@ def __call__(self, preds: torch.tensor, targets: torch.tensor): class PearsonCorr: + """ + Calculates the Pearson correlation coefficient between burdens and targets. + """ + def __init__(self): pass def __call__(self, burden, y): + """ + Calculate Pearson correlation coefficient. + + :param burden: Tensor containing burden values. + :type burden: torch.tensor + :param y: Tensor containing target values. + :type y: torch.tensor + :return: Pearson correlation coefficient. + :rtype: float + """ if len(burden.shape) > 1: # was the burden computed for >1 genes corrs = [] for i in range(burden.shape[1]): # number of genes @@ -48,10 +76,24 @@ def __call__(self, burden, y): class PearsonCorrTorch: + """ + Calculates the Pearson correlation coefficient between burdens and targets using PyTorch tensor operations. + """ + def __init__(self): pass def __call__(self, burden, y): + """ + Calculate Pearson correlation coefficient using PyTorch tensor operations. + + :param burden: Tensor containing burden values. + :type burden: torch.tensor + :param y: Tensor containing target values. + :type y: torch.tensor + :return: Pearson correlation coefficient. + :rtype: torch.tensor + """ if len(burden.shape) > 1: # was the burden computed for >1 genes corrs = [] for i in range(burden.shape[1]): # number of genes @@ -83,9 +125,23 @@ def calculate_pearsonr(self, x, y): class AveragePrecisionWithLogits: + """ + Calculates the average precision score between logits and targets. + """ + def __init__(self): pass def __call__(self, logits, y): + """ + Calculate average precision score. + + :param logits: Tensor containing logits. 
+ :type logits: torch.tensor + :param y: Tensor containing target values. + :type y: torch.tensor + :return: Average precision score. + :rtype: float + """ y_scores = F.sigmoid(logits.detach()) return average_precision_score(y.detach().cpu().numpy(), y_scores.cpu().numpy()) diff --git a/deeprvat/utils.py b/deeprvat/utils.py index ef9bdf99..3ecad145 100644 --- a/deeprvat/utils.py +++ b/deeprvat/utils.py @@ -24,6 +24,16 @@ def fdrcorrect_df(group: pd.DataFrame, alpha: float) -> pd.DataFrame: + """ + Apply False Discovery Rate (FDR) correction to p-values in a DataFrame. + + :param group: DataFrame containing a "pval" column. + :type group: pd.DataFrame + :param alpha: Significance level. + :type alpha: float + :return: Original DataFrame with additional columns "significant" and "pval_corrected". + :rtype: pd.DataFrame + """ group = group.copy() rejected, pval_corrected = fdrcorrection(group["pval"], alpha=alpha) @@ -33,6 +43,16 @@ def fdrcorrect_df(group: pd.DataFrame, alpha: float) -> pd.DataFrame: def bfcorrect_df(group: pd.DataFrame, alpha: float) -> pd.DataFrame: + """ + Apply Bonferroni correction to p-values in a DataFrame. + + :param group: DataFrame containing a "pval" column. + :type group: pd.DataFrame + :param alpha: Significance level. + :type alpha: float + :return: Original DataFrame with additional columns "significant" and "pval_corrected". + :rtype: pd.DataFrame + """ group = group.copy() pval_corrected = group["pval"] * len(group) @@ -42,6 +62,18 @@ def bfcorrect_df(group: pd.DataFrame, alpha: float) -> pd.DataFrame: def pval_correction(group: pd.DataFrame, alpha: float, correction_type: str = "FDR"): + """ + Apply p-value correction to a DataFrame. + + :param group: DataFrame containing a column named "pval" with p-values to correct. + :type group: pd.DataFrame + :param alpha: Significance level. + :type alpha: float + :param correction_type: Type of p-value correction. Options are 'FDR' (default) and 'Bonferroni'. + :type correction_type: str + :return: Original DataFrame with additional columns "significant" and "pval_corrected". + :rtype: pd.DataFrame + """ if correction_type == "FDR": corrected = fdrcorrect_df(group, alpha) elif correction_type == "Bonferroni": @@ -56,7 +88,21 @@ def pval_correction(group: pd.DataFrame, alpha: float, correction_type: str = "F return corrected -def suggest_hparams(config: Dict, trial: optuna.trial.Trial, basename: str = ""): +def suggest_hparams( + config: Dict, trial: optuna.trial.Trial, basename: str = "" +) -> Dict: + """ + Suggest hyperparameters using Optuna's suggest methods. + + :param config: Configuration dictionary with hyperparameter specifications. + :type config: Dict + :param trial: Optuna trial instance. + :type trial: optuna.trial.Trial + :param basename: Base name for hyperparameter suggestions. + :type basename: str + :return: Updated configuration with suggested hyperparameters. + :rtype: Dict + """ config = copy.deepcopy(config) for k, cfg in config.items(): if isinstance(cfg, dict): @@ -75,6 +121,14 @@ def suggest_hparams(config: Dict, trial: optuna.trial.Trial, basename: str = "") def compute_se(errors: np.ndarray) -> float: + """ + Compute standard error. + + :param errors: Array of errors. + :type errors: np.ndarray + :return: Standard error. 
+ :rtype: float + """ mean_error = np.mean(errors) n = errors.shape[0] error_variance = np.mean((errors - mean_error) ** 2) / (n - 1) * n @@ -82,6 +136,14 @@ def compute_se(errors: np.ndarray) -> float: def standardize_series(x: pd.Series) -> pd.Series: + """ + Standardize a pandas Series. + + :param x: Input Series. + :type x: pd.Series + :return: Standardized Series. + :rtype: pd.Series + """ x = x.astype(np.float32) mean = x.mean() variance = ((x - mean) ** 2).mean() @@ -91,7 +153,17 @@ def standardize_series(x: pd.Series) -> pd.Series: def my_quantile_transform(x, seed=1): """ - returns Gaussian quantile transformed values, "nan" are kept + Gaussian quantile transform for values in a pandas Series. + + :param x: Input pandas Series. + :type x: pd.Series + :param seed: Random seed. + :type seed: int + :return: Transformed Series. + :rtype: pd.Series + + .. note:: + "nan" values are kept """ np.random.seed(seed) x_transform = x.copy().to_numpy() @@ -110,11 +182,31 @@ def my_quantile_transform(x, seed=1): def standardize_series_with_params(x: pd.Series, std, mean) -> pd.Series: + """ + Standardize a pandas Series using provided standard deviation and mean. + + :param x: Input Series. + :type x: pd.Series + :param std: Standard deviation to use for standardization. + :param mean: Mean to use for standardization. + :return: Standardized Series. + :rtype: pd.Series + """ x = x.apply(lambda x: (x - mean) / std if x != 0 else 0) return x def calculate_mean_std(x: pd.Series, ignore_zero=True) -> pd.Series: + """ + Calculate mean and standard deviation of a pandas Series. + + :param x: Input Series. + :type x: pd.Series + :param ignore_zero: Whether to ignore zero values in calculations, defaults to True. + :type ignore_zero: bool + :return: Tuple of standard deviation and mean. + :rtype: Tuple[float, float] + """ x = x.astype(np.float32) if ignore_zero: x = x[x != float(0)] @@ -130,6 +222,22 @@ def safe_merge( validate: str = "1:1", equal_row_nums: bool = False, ): + """ + Safely merge two pandas DataFrames. + + :param left: Left DataFrame. + :type left: pd.DataFrame + :param right: Right DataFrame. + :type right: pd.DataFrame + :param validate: Validation method for the merge. + :type validate: str + :param equal_row_nums: Whether to check if the row numbers are equal, defaults to False. + :type equal_row_nums: bool + :raises ValueError: If left and right dataframe rows are unequal when 'equal_row_nums' is True. + :raises RuntimeError: If merged DataFrame has unequal row numbers compared to the left DataFrame. + :return: Merged DataFrame. + :rtype: pd.DataFrame + """ if equal_row_nums: try: assert len(left) == len(right) @@ -153,6 +261,14 @@ def safe_merge( def resolve_path_with_env(path: str) -> str: + """ + Resolve a path with environment variables. + + :param path: Input path. + :type path: str + :return: Resolved path. + :rtype: str + """ path_split = [] head = path while head not in ("", "/"): @@ -168,6 +284,16 @@ def resolve_path_with_env(path: str) -> str: def copy_with_env(path: str, destination: str) -> str: + """ + Copy a file or directory to a destination with environment variables. + + :param path: Input path (file or directory). + :type path: str + :param destination: Destination path. + :type destination: str + :return: Resulting destination path. 
+ :rtype: str + """ destination = resolve_path_with_env(destination) if os.path.isfile(path): @@ -191,6 +317,16 @@ def copy_with_env(path: str, destination: str) -> str: def load_or_init(pickle_file: str, init_fn: Callable) -> Any: + """ + Load a pickled file or initialize an object. + + :param pickle_file: Pickle file path. + :type pickle_file: str + :param init_fn: Initialization function. + :type init_fn: Callable + :return: Loaded or initialized object. + :rtype: Any + """ if pickle_file is not None and os.path.isfile(pickle_file): logger.info(f"Using pickled file {pickle_file}") with open(pickle_file, "rb") as f: @@ -204,6 +340,16 @@ def load_or_init(pickle_file: str, init_fn: Callable) -> Any: def remove_prefix(string, prefix): + """ + Remove a prefix from a string. + + :param string: Input string. + :type string: str + :param prefix: Prefix to remove. + :type prefix: str + :return: String without the specified prefix. + :rtype: str + """ if string.startswith(prefix): return string[len(prefix) :] return string @@ -218,6 +364,18 @@ def suggest_batch_size( }, buffer_bytes: int = 2_500_000_000, ): + """ + Suggest a batch size for a tensor based on available GPU memory. + + :param tensor_shape: Shape of the tensor. + :type tensor_shape: Iterable[int] + :param example: Example dictionary with batch size, tensor shape, and max memory bytes. + :type example: Dict[str, Any] + :param buffer_bytes: Buffer bytes to consider. + :type buffer_bytes: int + :return: Suggested batch size for the given tensor shape and GPU memory. + :rtype: int + """ gpu_mem_bytes = torch.cuda.get_device_properties(0).total_memory batch_size = math.floor( example["batch_size"]