Skip to content

Commit

Permalink
Merge branch 'main' into feature-make-the-preprocessing-less-ukbb-spe…
Browse files Browse the repository at this point in the history
…cific
  • Loading branch information
endast committed Nov 28, 2023
2 parents becea82 + 4dd98a9 commit 7c2b9a6
Show file tree
Hide file tree
Showing 4 changed files with 492 additions and 4 deletions.
256 changes: 254 additions & 2 deletions deeprvat/deeprvat/associate.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ def get_burden(
device: torch.device = torch.device("cpu"),
skip_burdens=False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute burden scores for rare variants.
:param batch: A dictionary containing batched data from the DataLoader.
:type batch: Dict
:param agg_models: Loaded PyTorch model(s) for each repeat used for burden computation.
Each key in the dictionary corresponds to a respective repeat.
:type agg_models: Dict[str, List[nn.Module]]
:param device: Device to perform computations on, defaults to "cpu".
:type device: torch.device
:param skip_burdens: Flag to skip burden computation, defaults to False.
:type skip_burdens: bool
:return: Tuple containing burden scores, target y phenotype values, and x phenotypes.
:rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
.. note::
Checkpoint models all corresponding to the same repeat are averaged for that repeat.
"""
with torch.no_grad():
X = batch["rare_variant_annotations"].to(device)
burden = []
Expand All @@ -74,6 +92,14 @@ def get_burden(


def separate_parallel_results(results: List) -> Tuple[List, ...]:
"""
Separate results from running regression on each gene.
:param results: List of results obtained from regression analysis.
:type results: List
:return: Tuple of lists containing separated results of regressed_genes, betas, and pvals.
:rtype: Tuple[List, ...]
"""
return tuple(map(list, zip(*results)))


Expand All @@ -88,6 +114,20 @@ def make_dataset_(
data_key="data",
samples: Optional[List[int]] = None,
) -> Dataset:
"""
Create a dataset based on the configuration.
:param config: Configuration dictionary.
:type config: Dict
:param debug: Flag for debugging, defaults to False.
:type debug: bool
:param data_key: Key for dataset configuration in the config dictionary, defaults to "data".
:type data_key: str
:param samples: List of sample indices to include in the dataset, defaults to None.
:type samples: List[int]
:return: Loaded instance of the created dataset.
:rtype: Dataset
"""
data_config = config[data_key]

ds_pickled = data_config.get("pickled", None)
Expand Down Expand Up @@ -122,6 +162,19 @@ def make_dataset_(
@click.argument("config-file", type=click.Path(exists=True))
@click.argument("out-file", type=click.Path())
def make_dataset(debug: bool, data_key: str, config_file: str, out_file: str):
"""
Create a dataset based on the provided configuration and save to a pickle file.
:param debug: Flag for debugging.
:type debug: bool
:param data_key: Key for dataset configuration in the config dictionary, defaults to "data".
:type data_key: str
:param config_file: Path to the configuration file.
:type config_file: str
:param out_file: Path to the output file.
:type out_file: str
:return: Created dataset saved to out_file.pkl
"""
with open(config_file) as f:
config = yaml.safe_load(f)

Expand All @@ -144,6 +197,38 @@ def compute_burdens_(
compression_level: int = 1,
skip_burdens: bool = False,
) -> Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array]:
"""
Compute burdens using the PyTorch model for each repeat.
:param debug: Flag for debugging.
:type debug: bool
:param config: Configuration dictionary.
:type config: Dict
:param ds: Torch dataset.
:type ds: torch.utils.data.Dataset
:param cache_dir: Directory to cache zarr files of computed burdens, x phenotypes, and y phenotypes.
:type cache_dir: str
:param agg_models: Loaded PyTorch model(s) for each repeat used for burden computation.
Each key in the dictionary corresponds to a respective repeat.
:type agg_models: Dict[str, List[nn.Module]]
:param n_chunks: Number of chunks to split data for processing, defaults to None.
:type n_chunks: Optional[int]
:param chunk: Index of the chunk of data, defaults to None.
:type chunk: Optional[int]
:param device: Device to perform computations on, defaults to "cpu".
:type device: torch.device
:param bottleneck: Flag to enable bottlenecking number of batches, defaults to False.
:type bottleneck: bool
:param compression_level: Blosc compressor compression level for zarr files, defaults to 1.
:type compression_level: int
:param skip_burdens: Flag to skip burden computation, defaults to False.
:type skip_burdens: bool
:return: Tuple containing genes, burdens, target y phenotypes, and x phenotypes.
:rtype: Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array]
.. note::
Checkpoint models all corresponding to the same repeat are averaged for that repeat.
"""
if not skip_burdens:
logger.info("agg_models[*][*].reverse:")
pprint(
Expand Down Expand Up @@ -272,6 +357,18 @@ def load_one_model(
checkpoint: str,
device: torch.device = torch.device("cpu"),
):
"""
Load a single burden score computation model from a checkpoint file.
:param config: Configuration dictionary.
:type config: Dict
:param checkpoint: Path to the model checkpoint file.
:type checkpoint: str
:param device: Device to load the model onto, defaults to "cpu".
:type device: torch.device
:return: Loaded PyTorch model for burden score computation.
:rtype: nn.Module
"""
model_class = getattr(deeprvat_models, config["model"]["type"])
model = model_class.load_from_checkpoint(
checkpoint,
Expand All @@ -290,6 +387,17 @@ def load_one_model(
def reverse_models(
model_config_file: str, data_config_file: str, checkpoint_files: Tuple[str]
):
"""
Determine if the burden score computation PyTorch model should reverse the output based on PLOF annotations.
:param model_config_file: Path to the model configuration file.
:type model_config_file: str
:param data_config_file: Path to the data configuration file.
:type data_config_file: str
:param checkpoint_files: Paths to checkpoint files.
:type checkpoint_files: Tuple[str]
:return: checkpoint.reverse file is created if the model should reverse the burden score output.
"""
with open(model_config_file) as f:
model_config = yaml.safe_load(f)

Expand Down Expand Up @@ -351,6 +459,25 @@ def load_models(
checkpoint_files: Tuple[str],
device: torch.device = torch.device("cpu"),
) -> Dict[str, List[nn.Module]]:
"""
Load models from multiple checkpoints for multiple repeats.
:param config: Configuration dictionary.
:type config: Dict
:param checkpoint_files: Paths to checkpoint files.
:type checkpoint_files: Tuple[str]
:param device: Device to load the models onto, defaults to "cpu".
:type device: torch.device
:return: Dictionary of loaded PyTorch models for burden score computation for each repeat.
:rtype: Dict[str, List[nn.Module]]
:Examples:
>>> config = {"model": {"type": "MyModel", "config": {"param": "value"}}}
>>> checkpoint_files = ("checkpoint1.pth", "checkpoint2.pth")
>>> load_models(config, checkpoint_files)
{'repeat_0': [MyModel(), MyModel()]}
"""
logger.info("Loading models and checkpoints")

if all(
Expand Down Expand Up @@ -430,6 +557,35 @@ def compute_burdens(
checkpoint_files: Tuple[str],
out_dir: str,
):
"""
Compute burdens based on the provided model and dataset.
:param debug: Flag for debugging.
:type debug: bool
:param bottleneck: Flag to enable bottlenecking number of batches.
:type bottleneck: bool
:param n_chunks: Number of chunks to split data for processing, defaults to None.
:type n_chunks: Optional[int]
:param chunk: Index of the chunk of data, defaults to None.
:type chunk: Optional[int]
:param dataset_file: Path to the dataset file, i.e., association_dataset.pkl.
:type dataset_file: Optional[str]
:param link_burdens: Path to burden.zarr file to link.
:type link_burdens: Optional[str]
:param data_config_file: Path to the data configuration file.
:type data_config_file: str
:param model_config_file: Path to the model configuration file.
:type model_config_file: str
:param checkpoint_files: Paths to model checkpoint files.
:type checkpoint_files: Tuple[str]
:param out_dir: Path to the output directory.
:type out_dir: str
:return: Corresonding genes, computed burdens, y phenotypes, and x phenotypes are saved in the out_dir.
:rtype: [np.ndarray], [zarr.core.Array], [zarr.core.Array], [zarr.core.Array]
.. note::
Checkpoint models all corresponding to the same repeat are averaged for that repeat.
"""
if len(checkpoint_files) == 0:
raise ValueError("At least one checkpoint file must be supplied")

Expand Down Expand Up @@ -479,7 +635,23 @@ def compute_burdens(
source_path.symlink_to(link_burdens)


def regress_on_gene_scoretest(gene: str, burdens: np.ndarray, model_score):
def regress_on_gene_scoretest(
gene: str,
burdens: np.ndarray,
model_score,
) -> Tuple[List[str], List[float], List[float]]:
"""
Perform regression on a gene using the score test.
:param gene: Gene name.
:type gene: str
:param burdens: Burden scores associated with the gene.
:type burdens: np.ndarray
:param model_score: Model for score test.
:type model_score: Any
:return: Tuple containing gene name, beta, and p-value.
:rtype: Tuple[List[str], List[float], List[float]]
"""
burdens = burdens.reshape(burdens.shape[0], -1)
logger.info(f"Burdens shape: {burdens.shape}")

Expand Down Expand Up @@ -517,7 +689,25 @@ def regress_on_gene(
x_pheno: np.ndarray,
use_bias: bool,
use_x_pheno: bool,
) -> Optional[Tuple[List[str], List[float], List[float]]]:
) -> Tuple[List[str], List[float], List[float]]:
"""
Perform regression on a gene using Ordinary Least Squares (OLS).
:param gene: Gene name.
:type gene: str
:param X: Burden score data.
:type X: np.ndarray
:param y: Y phenotype data.
:type y: np.ndarray
:param x_pheno: X phenotype data.
:type x_pheno: np.ndarray
:param use_bias: Flag to include bias term.
:type use_bias: bool
:param use_x_pheno: Flag to include x phenotype data in regression.
:type use_x_pheno: bool
:return: Tuple containing gene name, beta, and p-value.
:rtype: Tuple[List[str], List[float], List[float]]
"""
X = X.reshape(X.shape[0], -1)
if np.all(np.abs(X) < 1e-6):
logger.warning(f"Burden for gene {gene} is 0 for all samples; skipping")
Expand Down Expand Up @@ -561,6 +751,30 @@ def regress_(
use_x_pheno: bool = True,
do_scoretest: bool = True,
) -> pd.DataFrame:
"""
Perform regression on multiple genes.
:param config: Configuration dictionary.
:type config: Dict
:param use_bias: Flag to include bias term when performing OLS regression.
:type use_bias: bool
:param burdens: Burden score data.
:type burdens: np.ndarray
:param y: Y phenotype data.
:type y: np.ndarray
:param gene_indices: Indices of genes.
:type gene_indices: np.ndarray
:param genes: Gene names.
:type genes: pd.Series
:param x_pheno: X phenotype data.
:type x_pheno: np.ndarray
:param use_x_pheno: Flag to include x phenotype data when performing OLS regression, defaults to True.
:type use_x_pheno: bool
:param do_scoretest: Flag to use the scoretest from SEAK, defaults to True.
:type do_scoretest: bool
:return: DataFrame containing regression results on all genes.
:rtype: pd.DataFrame
"""
assert len(gene_indices) == len(genes)

logger.info(f"Computing associations")
Expand Down Expand Up @@ -640,6 +854,33 @@ def regress(
do_scoretest: bool,
sample_file: Optional[str],
):
"""
Perform regression analysis.
:param debug: Flag for debugging.
:type debug: bool
:param chunk: Index of the chunk of data, defaults to 0.
:type chunk: int
:param n_chunks: Number of chunks to split data for processing, defaults to 1.
:type n_chunks: int
:param use_bias: Flag to include bias term when performing OLS regression.
:type use_bias: bool
:param gene_file: Path to the gene file.
:type gene_file: str
:param repeat: Index of the repeat, defaults to 0.
:type repeat: int
:param config_file: Path to the configuration file.
:type config_file: str
:param burden_dir: Path to the directory containing burdens.zarr file.
:type burden_dir: str
:param out_dir: Path to the output directory.
:type out_dir: str
:param do_scoretest: Flag to use the scoretest from SEAK.
:type do_scoretest: bool
:param sample_file: Path to the sample file.
:type sample_file: Optional[str]
:return: Regression results saved to out_dir as "burden_associations_{chunk}.parquet"
"""
logger.info("Loading saved burdens")
y = zarr.open(Path(burden_dir) / "y.zarr")[:]
burdens = zarr.open(Path(burden_dir) / "burdens.zarr")[:, :, repeat]
Expand Down Expand Up @@ -708,6 +949,17 @@ def regress(
def combine_regression_results(
result_files: Tuple[str], out_file: str, model_name: Optional[str]
):
"""
Combine multiple regression result files.
:param result_files: List of paths to regression result files.
:type result_files: Tuple[str]
:param out_file: Path to the output file.
:type out_file: str
:param model_name: Name of the regression model.
:type model_name: Optional[str]
:return: Concatenated regression results saved to a parquet file.
"""
logger.info(f"Concatenating results")
results = pd.concat([pd.read_parquet(f, engine="pyarrow") for f in result_files])

Expand Down
22 changes: 22 additions & 0 deletions deeprvat/deeprvat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,28 @@ def update_config(
seed_genes_out: Optional[str],
new_config_file: str,
):
"""
Select seed genes based on baseline results and update the configuration file.
:param old_config_file: Path to the old configuration file.
:type old_config_file: str
:param phenotype: Phenotype to update in the configuration.
:type phenotype: Optional[str]
:param seed_gene_dir: Directory containing seed genes.
:type seed_gene_dir: Optional[str]
:param baseline_results: Paths to baseline result files.
:type baseline_results: Tuple[str]
:param baseline_results_out: Path to save the updated baseline results.
:type baseline_results_out: Optional[str]
:param seed_genes_out: Path to save the seed genes.
:type seed_genes_out: Optional[str]
:param new_config_file: Path to the new configuration file.
:type new_config_file: str
:raises ValueError: If neither --seed-gene-dir nor --baseline-results is specified.
:return: Updated configuration file saved to new_config.yaml.
Selected seed genes saved to seed_genes_out.parquet.
Optionally, save baseline results to a parquet file if baseline_results_out is specified.
"""
if seed_gene_dir is None and len(baseline_results) == 0:
raise ValueError(
"One of --seed-gene-dir and --baseline-results " "must be specified"
Expand Down
Loading

0 comments on commit 7c2b9a6

Please sign in to comment.