common variants - zarr creation and subsetting

meyerkm committed Dec 6, 2023
1 parent 3d25aa0 commit 83cf2fb
Showing 1 changed file with 52 additions and 14 deletions.

66 changes: 52 additions & 14 deletions deeprvat/deeprvat/train.py
@@ -82,8 +82,9 @@ def subset_samples(
input_tensor: torch.Tensor,
covariates: torch.Tensor,
y: torch.Tensor,
+ common_variants: torch.Tensor,
min_variant_count: int,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# First sum over annotations (dim 2) for each variant in each gene.
# Then get the number of non-zero values across all variants in all
# genes for each sample.
@@ -105,10 +106,13 @@ def subset_samples(
input_tensor = input_tensor[mask]
covariates = covariates[mask]
y = y[mask]

+ # TODO: Just use nan-mask for y_phenos
+ common_variants = common_variants[mask]

logger.info(f"{input_tensor.shape[0]} / {n_samples_orig} samples kept")

- return input_tensor, covariates, y
+ return input_tensor, covariates, y, common_variants
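As the comment at the top of subset_samples describes, a sample is kept only if it carries at least min_variant_count rare variants and its phenotype is not missing. A minimal standalone sketch of that masking logic; the shapes and values here are made up for illustration, not taken from the commit:

import torch

# Toy input: 4 samples, 2 genes, 3 annotations, 5 variants per gene.
input_tensor = torch.zeros(4, 2, 3, 5)
input_tensor[0, 0, :, :2] = 1.0  # sample 0 carries 2 variants
input_tensor[2, :, :, :3] = 1.0  # sample 2 carries 6 variants
input_tensor[3, 1, :, :1] = 1.0  # sample 3 carries only 1 variant
y = torch.tensor([[0.1], [float("nan")], [0.3], [0.4]])
min_variant_count = 2

# Sum over the annotation dim (2), then count non-zero entries across
# all genes and variants for each sample.
n_variants = (input_tensor.sum(dim=2) != 0).sum(dim=(1, 2))
mask = (n_variants >= min_variant_count) & ~y.squeeze(-1).isnan()
print(mask)  # tensor([True, False, True, False])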


@@ -209,6 +213,7 @@ def make_dataset_(
)
]
rare_batches = [b["rare_variant_annotations"] for b in batches]

max_n_variants = max(r.shape[-1] for r in rare_batches)
logging.info("MAXVAR:%s", max_n_variants)
logger.info("Building input_tensor, covariates, and y")
@@ -220,11 +225,21 @@ def make_dataset_(
)
covariates = torch.cat([b["x_phenotypes"] for b in batches])
y = torch.cat([b["y"] for b in batches])
common_variants = torch.cat([b["common_variants"] for b in batches])

common_var_batches = [b["common_variants"] for b in batches]
max_n_common_variants = max(r.shape[-1] for r in common_var_batches)
logging.info("MAX COMMMON VAR:%s", max_n_common_variants)
common_variants = torch.cat(
[
F.pad(r, (0, max_n_common_variants - r.shape[-1]), value=pad_value)
for r in common_var_batches
]
)

#Note: Common variants are only subsetted by missing y values mask
logger.info("Subsetting samples by min_variant_count and missing y values")
- input_tensor, covariates, y = subset_samples(
-     input_tensor, covariates, y, config["training"]["min_variant_count"]
+ input_tensor, covariates, y, common_variants = subset_samples(
+     input_tensor, covariates, y, common_variants, config["training"]["min_variant_count"]
)

return input_tensor, covariates, y, common_variants
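For reference on the padding above: F.pad with a (left, right) tuple pads the last dimension, so each batch's common-variant axis is right-padded to the widest batch before concatenation. A self-contained sketch with made-up batch shapes and pad value:

import torch
import torch.nn.functional as F

common_var_batches = [torch.ones(2, 3), torch.ones(4, 5)]  # ragged last dim
pad_value = 0.0
max_n = max(b.shape[-1] for b in common_var_batches)
padded = torch.cat(
    [F.pad(b, (0, max_n - b.shape[-1]), value=pad_value) for b in common_var_batches]
)
print(padded.shape)  # torch.Size([6, 5])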
@@ -294,7 +309,12 @@ def make_dataset(
del input_tensor
zarr.save_array(covariates_out_file, covariates.numpy())
zarr.save_array(y_out_file, y.numpy())
- zarr.save_array(common_vars_out_file, common_variants.numpy())
+ zarr.save_array(common_vars_out_file,
+     common_variants.numpy(),
+     chunks=(1000, None),
+     compressor=Blosc(clevel=compression_level),
+ )
+ del common_variants
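A minimal sketch of the chunked, compressed save pattern above, using the zarr v2 API as the surrounding code does; Blosc comes from numcodecs, and the path, shape, and compression level here are illustrative:

import numpy as np
import zarr
from numcodecs import Blosc

common_variants = np.random.rand(5000, 128)
# Chunk 1000 rows (samples) at a time; None gives each chunk the full
# extent of the second dimension.
zarr.save_array(
    "common_variants.zarr",
    common_variants,
    chunks=(1000, None),
    compressor=Blosc(clevel=1),
)
print(zarr.open("common_variants.zarr", mode="r").chunks)  # (1000, 128)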


class MultiphenoDataset(Dataset):
@@ -374,6 +394,14 @@ def __init__(
)
pheno_data["input_tensor_zarr"] = self.zarr_root[pheno]
# pheno_data["input_tensor"] = pheno_data["input_tensor_zarr"][:]
+ zarr.copy(
+     pheno_data["common_variants"],
+     self.zarr_root,
+     name=pheno,
+     chunks=(self.chunksize, None),
+     compressor=Blosc(clevel=1),
+ )
+ pheno_data["common_variants"] = self.zarr_root[pheno]
elif temp_dir is not None:
tensor_path = (
Path(self.input_tensor_dir.name) / pheno / "input_tensor.zarr"
@@ -386,6 +414,17 @@ def __init__(
)
pheno_data["input_tensor_zarr"] = zarr.open(tensor_path)

+ tensor_path_common_vars = (
+     Path(self.input_tensor_dir.name) / pheno / "common_variants.zarr"
+ )
+ zarr.copy(
+     pheno_data["common_variants"],
+     zarr.DirectoryStore(tensor_path_common_vars),
+     chunks=(self.chunksize, None),
+     compressor=Blosc(clevel=1),
+ )
+ pheno_data["common_variants"] = zarr.open(tensor_path_common_vars)

self.min_variant_count = min_variant_count
self.samples = {
pheno: pheno_data["samples"][split]
@@ -442,18 +481,19 @@ def __getitem__(self, index):
# else self.data[pheno]["input_tensor_zarr"][slice_, :, :, :]
# )
annotations = self.data[pheno]["input_tensor_zarr"][slice_, :, :, :]


result[pheno] = {
"indices": self.samples[pheno][slice_],
"covariates": self.data[pheno]["covariates"][slice_],
"rare_variant_annotations": torch.tensor(annotations),
"y": self.data[pheno]["y"][slice_],
"common_variants": self.data[pheno]["common_variants"][slice_],
"common_variants": torch.tensor(self.data[pheno]["common_variants"][slice_,:]),
}

return result
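Slicing a zarr array materializes a NumPy array, which is why the batch above wraps the slice in torch.tensor(...) while the covariates and y, already tensors, are indexed directly. A quick illustration with made-up data:

import numpy as np
import torch
import zarr

z = zarr.array(np.arange(12, dtype=np.float32).reshape(4, 3))
chunk = z[0:2, :]            # zarr slicing returns a NumPy array
batch = torch.tensor(chunk)  # convert for the training loop
print(type(chunk).__name__, batch.shape)  # ndarray torch.Size([2, 3])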

- # NOTE: This function is broken with current cache_tensors behavior
+ # NOTE: This function is broken with current cache_tensors behavior. It is also NOT set up for common variants.
def subset_samples(self):
"""
Function used to sort out samples which contain real phenotypes with NaN values and
@@ -1004,7 +1044,7 @@ def train(

data = dict()
# pack underlying data into a single dict that can be passed to downstream functions
- for pheno, input_tensor_file, covariates_file, y_file in phenotype:
+ for pheno, input_tensor_file, covariates_file, y_file, common_vars_file in phenotype:
data[pheno] = dict()
data[pheno]["input_tensor_zarr"] = zarr.open(
input_tensor_file, mode="r"
@@ -1017,11 +1057,9 @@ def train(
data[pheno]["y"] = torch.tensor(zarr.open(y_file, mode="r")[:])[
samples
] # TODO: or maybe shouldn't subset here?
data[pheno]["common_variants"] = torch.tensor(
zarr.open(common_vars_file, mode="r")[:]
)[
samples
]
data[pheno]["common_variants"] = zarr.open(
common_vars_file, mode="r"
)

if training_gene_file is not None:
with open(training_gene_file, "rb") as f: