diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index ed5f6221..e2245ff7 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -439,18 +439,14 @@ def __getitem__(self, index): assert np.array_equal(idx, np.arange(idx[0], idx[-1] + 1)) slice_ = slice(idx[0], idx[-1] + 1) - # annotations = ( - # self.data[pheno]["input_tensor"][slice_] - # if self.cache_tensors - # else self.data[pheno]["input_tensor_zarr"][slice_, :, :, :] - # ) - annotations = self.data[pheno]["input_tensor_zarr"][slice_, :, :, :] + indices = self.samples[pheno][slice_] + annotations = self.data[pheno]["input_tensor_zarr"].oindex[indices, :, :, :] result[pheno] = { - "indices": self.samples[pheno][slice_], - "covariates": self.data[pheno]["covariates"][slice_], + "indices": indices, + "covariates": self.data[pheno]["covariates"][indices], "rare_variant_annotations": torch.tensor(annotations), - "y": self.data[pheno]["y"][slice_], + "y": self.data[pheno]["y"][indices], } return result diff --git a/pipelines/training/train.snakefile b/pipelines/training/train.snakefile index 23600b07..a9261067 100644 --- a/pipelines/training/train.snakefile +++ b/pipelines/training/train.snakefile @@ -61,7 +61,7 @@ rule train: mem_mb = 20000, gpus = 1 shell: - f"parallel --jobs {n_parallel_training_jobs} --halt now,fail=1 --results train_repeat{{{{1}}}}_trial{{{{2}}}}/ " + f"parallel --jobs {n_parallel_training_jobs} --halt now,fail=1 --results {{params.prefix}}/train_repeat{{{{1}}}}_trial{{{{2}}}}/ " 'deeprvat_train train ' + debug + deterministic +