
Commit 58db4c4
bug fixes
bfclarke committed Nov 28, 2023 (1 parent: fe1680d)
Showing 1 changed file with 20 additions and 7 deletions.
deeprvat/deeprvat/train.py
@@ -29,7 +29,7 @@
     PearsonCorrTorch,
     RSquared,
 )
-from deeprvat.utils import suggest_hparams
+from deeprvat.utils import resolve_path_with_env, suggest_hparams
 from numcodecs import Blosc
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
@@ -226,6 +226,7 @@ def __init__(
         split: str = "train",
         cache_tensors: bool = False,
         temp_dir: Optional[str] = None,
+        chunksize: int = 1000,
         # samples: Optional[Union[slice, np.ndarray]] = None,
         # genes: Optional[Union[slice, np.ndarray]] = None
     ):
@@ -239,6 +240,7 @@ def __init__(
         )
 
         self.cache_tensors = cache_tensors
+        self.chunksize = chunksize
         if self.cache_tensors:
             logger.info("Keeping all input tensors in main memory")
 
@@ -258,10 +260,11 @@ def __init__(
             pheno: pheno_data["samples"][split]
             for pheno, pheno_data in self.data.items()
         }
-        if temp_dir is not None:
-            self.input_tensor_dir = temp_dir
-        else:
-            self.input_tensor_dir = TemporaryDirectory(prefix="training_data", dir=".")
+        temp_path = (Path(resolve_path_with_env(temp_dir)) / "deeprvat_training"
+                     if temp_dir is not None
+                     else Path("deeprvat_training"))
+        temp_path.mkdir(parents=True, exist_ok=True)
+        self.input_tensor_dir = TemporaryDirectory(prefix="training_data", dir=str(temp_path))
 
         self.subset_samples()
 
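This hunk fixes a type inconsistency: the old code stored a plain string in self.input_tensor_dir when temp_dir was supplied, but a TemporaryDirectory object otherwise, so downstream code could not rely on a single interface (e.g. the .name attribute). The new code always constructs a TemporaryDirectory, rooted at a resolved base path. A minimal sketch of the resolution logic in isolation, assuming resolve_path_with_env expands environment variables in the supplied path (make_training_tmpdir is a hypothetical helper, named here only for illustration):

    from pathlib import Path
    from tempfile import TemporaryDirectory

    from deeprvat.utils import resolve_path_with_env

    def make_training_tmpdir(temp_dir=None):  # hypothetical helper
        # Resolve the base directory; fall back to ./deeprvat_training
        base = (Path(resolve_path_with_env(temp_dir)) / "deeprvat_training"
                if temp_dir is not None
                else Path("deeprvat_training"))
        # TemporaryDirectory requires its parent directory to exist already
        base.mkdir(parents=True, exist_ok=True)
        return TemporaryDirectory(prefix="training_data", dir=str(base))

    tmp = make_training_tmpdir("$TMPDIR")
    print(tmp.name)  # e.g. /tmp/deeprvat_training/training_dataxyz123

Note the mkdir call: creating the base directory up front avoids a FileNotFoundError from TemporaryDirectory when the resolved path does not yet exist.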
@@ -353,7 +356,7 @@ def subset_samples(self):
             zarr.save_array(
                 tensor_path,
                 pheno_data["input_tensor_zarr"][:][self.samples[pheno]],
-                chunks=(1000, None, None, None),
+                chunks=(self.chunksize, None, None, None),
                 compressor=Blosc(clevel=1),
             )
             pheno_data["input_tensor_zarr"] = zarr.open(tensor_path)
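The hardcoded zarr chunk size along the sample axis is now configurable. In a chunks tuple, None means one chunk spanning that whole axis, so (chunksize, None, None, None) chunks only the first (sample) dimension. A standalone illustration of the semantics (zarr v2 API, as used here; the array shape is invented):

    import numpy as np
    import zarr
    from numcodecs import Blosc

    arr = np.zeros((5000, 10, 20, 2), dtype=np.float32)
    zarr.save_array(
        "example.zarr",
        arr,
        chunks=(1000, None, None, None),  # chunk samples; keep other axes whole
        compressor=Blosc(clevel=1),       # light, fast compression
    )
    z = zarr.open("example.zarr")
    print(z.chunks)  # (1000, 10, 20, 2)

Smaller chunks reduce peak memory when reading subsets of samples; larger chunks reduce per-chunk overhead when streaming the whole tensor.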
@@ -363,7 +366,7 @@ def subset_samples(self):
             f"{n_samples_orig} samples kept"
         )
 
-    def index_input_tensor_zarr(pheno: str, indices: np.ndarray):
+    def index_input_tensor_zarr(self, pheno: str, indices: np.ndarray):
         # IMPORTANT!!! Never call this function after self.subset_samples()
 
         x = self.data[pheno]["input_tensor_zarr"]
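The fix above adds the missing self parameter: without it, any call on an instance fails before reaching the method body, because Python implicitly passes the instance as the first positional argument. A minimal repro (names illustrative):

    import numpy as np

    class Broken:
        def index_input_tensor_zarr(pheno: str, indices: np.ndarray):  # missing self
            return pheno

    try:
        Broken().index_input_tensor_zarr("pheno1", np.array([0]))
    except TypeError as e:
        print(e)  # takes 2 positional arguments but 3 were given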
@@ -388,6 +391,8 @@ def __init__(
         num_workers: Optional[int] = 0,
         pin_memory: bool = False,
         cache_tensors: bool = False,
+        temp_dir: Optional[str] = None,
+        chunksize: int = 1000,
     ):
         logger.info("Intializing datamodule")
 
@@ -451,6 +456,8 @@ def __init__(
             "num_workers",
             "pin_memory",
             "cache_tensors",
+            "temp_dir",
+            "chunksize",
         )
 
     def upsample(self) -> np.ndarray:
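These strings are presumably arguments to Lightning's self.save_hyperparameters (the call itself sits just above the hunk): only names registered there become available on self.hparams, which is why temp_dir and chunksize must be listed here before the dataloader hooks below can read them. A condensed sketch of the pattern (DataModuleSketch is a simplified stand-in, not the real class):

    import pytorch_lightning as pl

    class DataModuleSketch(pl.LightningDataModule):  # simplified stand-in
        def __init__(self, temp_dir=None, chunksize=1000, cache_tensors=False):
            super().__init__()
            # Only the names listed here are exposed via self.hparams
            self.save_hyperparameters("temp_dir", "chunksize", "cache_tensors")

    dm = DataModuleSketch(temp_dir="/scratch", chunksize=500)
    print(dm.hparams.chunksize)  # 500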
@@ -488,6 +495,8 @@ def train_dataloader(self):
             self.hparams.batch_size,
             split="train",
             cache_tensors=self.hparams.cache_tensors,
+            temp_dir=self.hparams.temp_dir,
+            chunksize=self.hparams.chunksize,
         )
         return DataLoader(
             dataset,
@@ -507,6 +516,8 @@ def val_dataloader(self):
             self.hparams.batch_size,
             split="val",
             cache_tensors=self.hparams.cache_tensors,
+            temp_dir=self.hparams.temp_dir,
+            chunksize=self.hparams.chunksize,
         )
         return DataLoader(
             dataset,
@@ -565,6 +576,8 @@ def run_bagging(
                 "upsampling_factor",
                 "sample_with_replacement",
                 "cache_tensors",
+                "temp_dir",
+                "chunksize",
             )
         }
         dm = MultiphenoBaggingData(
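This last hunk is truncated, but its shape suggests a dict comprehension that forwards only recognized keys from the training config to MultiphenoBaggingData; adding the two new keys lets temp_dir and chunksize flow from the config into the datamodule. An illustrative, self-contained version of that filtering pattern (the config contents below are invented):

    config = {
        "upsampling_factor": 1,
        "cache_tensors": True,
        "temp_dir": "$TMPDIR",
        "chunksize": 500,
        "learning_rate": 1e-3,  # not a datamodule kwarg; filtered out
    }
    dm_kwargs = {
        k: v
        for k, v in config.items()
        if k in ("upsampling_factor", "sample_with_replacement",
                 "cache_tensors", "temp_dir", "chunksize")
    }
    print(dm_kwargs)  # learning_rate is excluded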
