Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bfclarke committed Nov 28, 2023
1 parent d121eff commit 93c8227
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions deeprvat/deeprvat/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def __getitem__(self, index):
for pheno, df in samples_by_pheno:
idx = df["index"].to_numpy()
assert np.array_equal(idx, np.arange(idx[0], idx[-1] + 1))
slice_ = slice(idx[0], idx[-1] + 2)
slice_ = slice(idx[0], idx[-1] + 1)

annotations = (
self.data[pheno]["input_tensor"][slice_]
Expand All @@ -312,7 +312,7 @@ def __getitem__(self, index):
result[pheno] = {
"indices": self.samples[pheno][slice_],
"covariates": self.data[pheno]["covariates"][slice_],
"rare_variant_annotations": annotations,
"rare_variant_annotations": torch.tensor(annotations),
"y": self.data[pheno]["y"][slice_],
}

Expand Down Expand Up @@ -341,9 +341,9 @@ def subset_samples(self):
pheno_data["y"] = pheno_data["y"][self.samples[pheno]]
pheno_data["covariates"] = pheno_data["covariates"][self.samples[pheno]]
if self.cache_tensors:
pheno_data["input_tensor"] = pheno_data["input_tensor_zarr"][:][
pheno_data["input_tensor"] = pheno_data["input_tensor"][
self.samples[pheno]
] # TODO: Check this line
]
else:
# TODO: Again do this in blocks of 10,000 samples
# Create a temporary directory to store the zarr array
Expand Down

0 comments on commit 93c8227

Please sign in to comment.