From ad7fcb35f03f44f2bc9c4788fe1c54fb8087f16b Mon Sep 17 00:00:00 2001 From: Brian Clarke <9725212+bfclarke@users.noreply.github.com> Date: Fri, 11 Oct 2024 18:50:32 +0200 Subject: [PATCH] Bug/validation samples (#139) * correct the selection of training and validation samples * write training stdout and stderr to cv_split* directories --------- Co-authored-by: Brian Clarke --- deeprvat/deeprvat/train.py | 14 +++++--------- pipelines/training/train.snakefile | 2 +- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index ed5f6221..e2245ff7 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -439,18 +439,14 @@ def __getitem__(self, index): assert np.array_equal(idx, np.arange(idx[0], idx[-1] + 1)) slice_ = slice(idx[0], idx[-1] + 1) - # annotations = ( - # self.data[pheno]["input_tensor"][slice_] - # if self.cache_tensors - # else self.data[pheno]["input_tensor_zarr"][slice_, :, :, :] - # ) - annotations = self.data[pheno]["input_tensor_zarr"][slice_, :, :, :] + indices = self.samples[pheno][slice_] + annotations = self.data[pheno]["input_tensor_zarr"].oindex[indices, :, :, :] result[pheno] = { - "indices": self.samples[pheno][slice_], - "covariates": self.data[pheno]["covariates"][slice_], + "indices": indices, + "covariates": self.data[pheno]["covariates"][indices], "rare_variant_annotations": torch.tensor(annotations), - "y": self.data[pheno]["y"][slice_], + "y": self.data[pheno]["y"][indices], } return result diff --git a/pipelines/training/train.snakefile b/pipelines/training/train.snakefile index 23600b07..a9261067 100644 --- a/pipelines/training/train.snakefile +++ b/pipelines/training/train.snakefile @@ -61,7 +61,7 @@ rule train: mem_mb = 20000, gpus = 1 shell: - f"parallel --jobs {n_parallel_training_jobs} --halt now,fail=1 --results train_repeat{{{{1}}}}_trial{{{{2}}}}/ " + f"parallel --jobs {n_parallel_training_jobs} --halt now,fail=1 --results 
{{params.prefix}}/train_repeat{{{{1}}}}_trial{{{{2}}}}/ " 'deeprvat_train train ' + debug + deterministic +