Skip to content

Commit

Permalink
Feature burden sample ids (#45)
Browse files Browse the repository at this point in the history
* additional sample id output for burden computation
saved as zarr array

* fixup! Format Python code with psf/black pull_request

* fix docstring for readthedocs

* fixup! Format Python code with psf/black pull_request

---------

Co-authored-by: PMBio <PMBio@users.noreply.github.com>
  • Loading branch information
meyerkm and PMBio authored Dec 15, 2023
1 parent e4cc44c commit 67c70ae
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions deeprvat/deeprvat/associate.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def get_burden(
agg_models: Dict[str, List[nn.Module]],
device: torch.device = torch.device("cpu"),
skip_burdens=False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute burden scores for rare variants.
Expand All @@ -63,8 +63,8 @@ def get_burden(
:type device: torch.device
:param skip_burdens: Flag to skip burden computation, defaults to False.
:type skip_burdens: bool
:return: Tuple containing burden scores, target y phenotype values, and x phenotypes.
:rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
:return: Tuple containing burden scores, target y phenotype values, x phenotypes and sample ids.
:rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
.. note::
Checkpoint models all corresponding to the same repeat are averaged for that repeat.
Expand All @@ -87,8 +87,9 @@ def get_burden(

y = batch["y"]
x = batch["x_phenotypes"]
sample_ids = batch["sample"]

return burden, y, x
return burden, y, x, sample_ids


def separate_parallel_results(results: List) -> Tuple[List, ...]:
Expand Down Expand Up @@ -196,7 +197,9 @@ def compute_burdens_(
bottleneck: bool = False,
compression_level: int = 1,
skip_burdens: bool = False,
) -> Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array]:
) -> Tuple[
np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array, zarr.core.Array
]:
"""
Compute burdens using the PyTorch model for each repeat.
Expand All @@ -223,8 +226,8 @@ def compute_burdens_(
:type compression_level: int
:param skip_burdens: Flag to skip burden computation, defaults to False.
:type skip_burdens: bool
:return: Tuple containing genes, burdens, target y phenotypes, and x phenotypes.
:rtype: Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array]
:return: Tuple containing genes, burdens, target y phenotypes, x phenotypes and sample ids.
:rtype: Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array, zarr.core.Array]
.. note::
Checkpoint models all corresponding to the same repeat are averaged for that repeat.
Expand Down Expand Up @@ -280,14 +283,15 @@ def compute_burdens_(
file=sys.stdout,
total=(n_samples // batch_size + (n_samples % batch_size != 0)),
):
this_burdens, this_y, this_x = get_burden(
this_burdens, this_y, this_x, this_sampleid = get_burden(
batch, agg_models, device=device, skip_burdens=skip_burdens
)
if i == 0:
if not skip_burdens:
chunk_burden = np.zeros(shape=(n_samples,) + this_burdens.shape[1:])
chunk_y = np.zeros(shape=(n_samples,) + this_y.shape[1:])
chunk_x = np.zeros(shape=(n_samples,) + this_x.shape[1:])
chunk_sampleid = np.zeros(shape=(n_samples))

logger.info(f"Batch size: {batch['rare_variant_annotations'].shape}")

Expand Down Expand Up @@ -320,6 +324,14 @@ def compute_burdens_(
dtype=np.float32,
compressor=Blosc(clevel=compression_level),
)
sample_ids = zarr.open(
Path(cache_dir) / "sample_ids.zarr",
mode="a",
shape=(n_total_samples),
chunks=(None),
dtype=np.float32,
compressor=Blosc(clevel=compression_level),
)

start_idx = i * batch_size
end_idx = min(start_idx + batch_size, chunk_end) # read from chunk shape
Expand All @@ -329,6 +341,7 @@ def compute_burdens_(

chunk_y[start_idx:end_idx] = this_y
chunk_x[start_idx:end_idx] = this_x
chunk_sampleid[start_idx:end_idx] = this_sampleid

if debug:
logger.info(
Expand All @@ -343,13 +356,14 @@ def compute_burdens_(

y[chunk_start:chunk_end] = chunk_y
x[chunk_start:chunk_end] = chunk_x
sample_ids[chunk_start:chunk_end] = chunk_sampleid

if torch.cuda.is_available():
logger.info(
"Max GPU memory allocated: " f"{torch.cuda.max_memory_allocated(0)} bytes"
)

return ds_full.rare_embedding.genes, burdens, y, x
return ds_full.rare_embedding.genes, burdens, y, x, sample_ids


def load_one_model(
Expand Down Expand Up @@ -580,8 +594,8 @@ def compute_burdens(
:type checkpoint_files: Tuple[str]
:param out_dir: Path to the output directory.
:type out_dir: str
:return: Corresonding genes, computed burdens, y phenotypes, and x phenotypes are saved in the out_dir.
:rtype: [np.ndarray], [zarr.core.Array], [zarr.core.Array], [zarr.core.Array]
:return: Corresonding genes, computed burdens, y phenotypes, x phenotypes and sample ids are saved in the out_dir.
:rtype: [np.ndarray], [zarr.core.Array], [zarr.core.Array], [zarr.core.Array], [zarr.core.Array]
.. note::
Checkpoint models all corresponding to the same repeat are averaged for that repeat.
Expand Down Expand Up @@ -614,7 +628,7 @@ def compute_burdens(
else:
agg_models = None

genes, _, _, _ = compute_burdens_(
genes, _, _, _, _ = compute_burdens_(
debug,
data_config,
dataset,
Expand Down

0 comments on commit 67c70ae

Please sign in to comment.