diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index e8215d21..95487faf 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -301,7 +301,7 @@ def __getitem__(self, index): for pheno, df in samples_by_pheno: idx = df["index"].to_numpy() assert np.array_equal(idx, np.arange(idx[0], idx[-1] + 1)) - slice_ = slice(idx[0], idx[-1] + 2) + slice_ = slice(idx[0], idx[-1] + 1) annotations = ( self.data[pheno]["input_tensor"][slice_] @@ -312,7 +312,7 @@ def __getitem__(self, index): result[pheno] = { "indices": self.samples[pheno][slice_], "covariates": self.data[pheno]["covariates"][slice_], - "rare_variant_annotations": annotations, + "rare_variant_annotations": torch.tensor(annotations), "y": self.data[pheno]["y"][slice_], } @@ -341,9 +341,9 @@ def subset_samples(self): pheno_data["y"] = pheno_data["y"][self.samples[pheno]] pheno_data["covariates"] = pheno_data["covariates"][self.samples[pheno]] if self.cache_tensors: - pheno_data["input_tensor"] = pheno_data["input_tensor_zarr"][:][ + pheno_data["input_tensor"] = pheno_data["input_tensor"][ self.samples[pheno] - ] # TODO: Check this line + ] else: # TODO: Again do this in blocks of 10,000 samples # Create a temporary directory to store the zarr array