From 58db4c4b6786c7b0cb77e78de37736a58a7e38ab Mon Sep 17 00:00:00 2001 From: Brian Clarke Date: Tue, 28 Nov 2023 18:48:20 +0100 Subject: [PATCH] bug fixes --- deeprvat/deeprvat/train.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index 95487faf..e4bbff69 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -29,7 +29,7 @@ PearsonCorrTorch, RSquared, ) -from deeprvat.utils import suggest_hparams +from deeprvat.utils import resolve_path_with_env, suggest_hparams from numcodecs import Blosc from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks.early_stopping import EarlyStopping @@ -226,6 +226,7 @@ def __init__( split: str = "train", cache_tensors: bool = False, temp_dir: Optional[str] = None, + chunksize: int = 1000, # samples: Optional[Union[slice, np.ndarray]] = None, # genes: Optional[Union[slice, np.ndarray]] = None ): @@ -239,6 +240,7 @@ def __init__( ) self.cache_tensors = cache_tensors + self.chunksize = chunksize if self.cache_tensors: logger.info("Keeping all input tensors in main memory") @@ -258,10 +260,11 @@ def __init__( pheno: pheno_data["samples"][split] for pheno, pheno_data in self.data.items() } - if temp_dir is not None: - self.input_tensor_dir = temp_dir - else: - self.input_tensor_dir = TemporaryDirectory(prefix="training_data", dir=".") + temp_path = (Path(resolve_path_with_env(temp_dir)) / "deeprvat_training" + if temp_dir is not None + else Path("deeprvat_training")) + temp_path.mkdir(parents=True, exist_ok=True) + self.input_tensor_dir = TemporaryDirectory(prefix="training_data", dir=str(temp_path)) self.subset_samples() @@ -353,7 +356,7 @@ def subset_samples(self): zarr.save_array( tensor_path, pheno_data["input_tensor_zarr"][:][self.samples[pheno]], - chunks=(1000, None, None, None), + chunks=(self.chunksize, None, None, None), compressor=Blosc(clevel=1), ) pheno_data["input_tensor_zarr"] = zarr.open(tensor_path) @@ -363,7 +366,7 @@ def subset_samples(self): f"{n_samples_orig} samples kept" ) - def index_input_tensor_zarr(pheno: str, indices: np.ndarray): + def index_input_tensor_zarr(self, pheno: str, indices: np.ndarray): # IMPORTANT!!! Never call this function after self.subset_samples() x = self.data[pheno]["input_tensor_zarr"] @@ -388,6 +391,8 @@ def __init__( num_workers: Optional[int] = 0, pin_memory: bool = False, cache_tensors: bool = False, + temp_dir: Optional[str] = None, + chunksize: int = 1000, ): logger.info("Intializing datamodule") @@ -451,6 +456,8 @@ def __init__( "num_workers", "pin_memory", "cache_tensors", + "temp_dir", + "chunksize", ) def upsample(self) -> np.ndarray: @@ -488,6 +495,8 @@ def train_dataloader(self): self.hparams.batch_size, split="train", cache_tensors=self.hparams.cache_tensors, + temp_dir=self.hparams.temp_dir, + chunksize=self.hparams.chunksize, ) return DataLoader( dataset, @@ -507,6 +516,8 @@ def val_dataloader(self): self.hparams.batch_size, split="val", cache_tensors=self.hparams.cache_tensors, + temp_dir=self.hparams.temp_dir, + chunksize=self.hparams.chunksize, ) return DataLoader( dataset, @@ -565,6 +576,8 @@ def run_bagging( "upsampling_factor", "sample_with_replacement", "cache_tensors", + "temp_dir", + "chunksize", ) } dm = MultiphenoBaggingData(