diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index 909090d0..d13e71be 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -82,8 +82,9 @@ def subset_samples( input_tensor: torch.Tensor, covariates: torch.Tensor, y: torch.Tensor, + common_variants: torch.Tensor, min_variant_count: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # First sum over annotations (dim 2) for each variant in each gene. # Then get the number of non-zero values across all variants in all # genes for each sample. @@ -105,10 +106,13 @@ def subset_samples( input_tensor = input_tensor[mask] covariates = covariates[mask] y = y[mask] + + #TODO: Just use nan-mask for y_phenos + common_variants = common_variants[mask] logger.info(f"{input_tensor.shape[0]} / {n_samples_orig} samples kept") - return input_tensor, covariates, y + return input_tensor, covariates, y, common_variants def make_dataset_( @@ -209,6 +213,7 @@ def make_dataset_( ) ] rare_batches = [b["rare_variant_annotations"] for b in batches] + max_n_variants = max(r.shape[-1] for r in rare_batches) logging.info("MAXVAR:%s", max_n_variants) logger.info("Building input_tensor, covariates, and y") @@ -220,11 +225,21 @@ def make_dataset_( ) covariates = torch.cat([b["x_phenotypes"] for b in batches]) y = torch.cat([b["y"] for b in batches]) - common_variants = torch.cat([b["common_variants"] for b in batches]) + + common_var_batches = [b["common_variants"] for b in batches] + max_n_common_variants = max(r.shape[-1] for r in common_var_batches) + logging.info("MAX COMMMON VAR:%s", max_n_common_variants) + common_variants = torch.cat( + [ + F.pad(r, (0, max_n_common_variants - r.shape[-1]), value=pad_value) + for r in common_var_batches + ] + ) + #Note: Common variants are only subsetted by missing y values mask logger.info("Subsetting samples by min_variant_count and missing y values") - input_tensor, covariates, y = subset_samples( - input_tensor, covariates, y, config["training"]["min_variant_count"] + input_tensor, covariates, y, common_variants = subset_samples( + input_tensor, covariates, y, common_variants, config["training"]["min_variant_count"] ) return input_tensor, covariates, y, common_variants @@ -294,7 +309,12 @@ def make_dataset( del input_tensor zarr.save_array(covariates_out_file, covariates.numpy()) zarr.save_array(y_out_file, y.numpy()) - zarr.save_array(common_vars_out_file, common_variants.numpy()) + zarr.save_array(common_vars_out_file, + common_variants.numpy(), + chunks=(1000, None), + compressor=Blosc(clevel=compression_level), + ) + del common_variants class MultiphenoDataset(Dataset): @@ -374,6 +394,14 @@ def __init__( ) pheno_data["input_tensor_zarr"] = self.zarr_root[pheno] # pheno_data["input_tensor"] = pheno_data["input_tensor_zarr"][:] + zarr.copy( + pheno_data["common_variants"], + self.zarr_root, + name=pheno, + chunks=(self.chunksize, None), + compressor=Blosc(clevel=1), + ) + pheno_data["common_variants"] = self.zarr_root[pheno] elif temp_dir is not None: tensor_path = ( Path(self.input_tensor_dir.name) / pheno / "input_tensor.zarr" @@ -386,6 +414,17 @@ def __init__( ) pheno_data["input_tensor_zarr"] = zarr.open(tensor_path) + tensor_path_common_vars = ( + Path(self.input_tensor_dir.name) / pheno / "common_variants.zarr" + ) + zarr.copy( + pheno_data["common_variants"], + zarr.DirectoryStore(tensor_path_common_vars), + chunks=(self.chunksize, None), + compressor=Blosc(clevel=1), + ) + pheno_data["common_variants"] = zarr.open(tensor_path_common_vars) + self.min_variant_count = min_variant_count self.samples = { pheno: pheno_data["samples"][split] @@ -442,18 +481,19 @@ def __getitem__(self, index): # else self.data[pheno]["input_tensor_zarr"][slice_, :, :, :] # ) annotations = self.data[pheno]["input_tensor_zarr"][slice_, :, :, :] + result[pheno] = { "indices": self.samples[pheno][slice_], "covariates": self.data[pheno]["covariates"][slice_], "rare_variant_annotations": torch.tensor(annotations), "y": self.data[pheno]["y"][slice_], - "common_variants": self.data[pheno]["common_variants"][slice_], + "common_variants": torch.tensor(self.data[pheno]["common_variants"][slice_,:]), } return result - # NOTE: This function is broken with current cache_tensors behavior + # NOTE: This function is broken with current cache_tensors behavior. It is also NOT setup for common variants def subset_samples(self): """ Function used to sort out samples which contain real phenotypes with NaN values and @@ -1004,7 +1044,7 @@ def train( data = dict() # pack underlying data into a single dict that can be passed to downstream functions - for pheno, input_tensor_file, covariates_file, y_file in phenotype: + for pheno, input_tensor_file, covariates_file, y_file, common_vars_file in phenotype: data[pheno] = dict() data[pheno]["input_tensor_zarr"] = zarr.open( input_tensor_file, mode="r" @@ -1017,11 +1057,9 @@ def train( data[pheno]["y"] = torch.tensor(zarr.open(y_file, mode="r")[:])[ samples ] # TODO: or maybe shouldn't subset here? - data[pheno]["common_variants"] = torch.tensor( - zarr.open(common_vars_file, mode="r")[:] - )[ - samples - ] + data[pheno]["common_variants"] = zarr.open( + common_vars_file, mode="r" + ) if training_gene_file is not None: with open(training_gene_file, "rb") as f: