diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index d13e71be..099525eb 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -364,6 +364,8 @@ def __init__( self.cache_tensors = cache_tensors if self.cache_tensors: self.zarr_root = zarr.group() + self.input_tensor_zroot = self.zarr_root.create_group("input_tensor") + self.common_variants_zroot = self.zarr_root.create_group("common_variants") elif temp_dir is not None: temp_path = Path(resolve_path_with_env(temp_dir)) / "deeprvat_training" temp_path.mkdir(parents=True, exist_ok=True) @@ -387,21 +389,21 @@ def __init__( if self.cache_tensors: zarr.copy( pheno_data["input_tensor_zarr"], - self.zarr_root, + self.input_tensor_zroot, name=pheno, chunks=(self.chunksize, None, None, None), compressor=Blosc(clevel=1), ) - pheno_data["input_tensor_zarr"] = self.zarr_root[pheno] + pheno_data["input_tensor_zarr"] = self.input_tensor_zroot[pheno] # pheno_data["input_tensor"] = pheno_data["input_tensor_zarr"][:] zarr.copy( pheno_data["common_variants"], - self.zarr_root, + self.common_variants_zroot, name=pheno, chunks=(self.chunksize, None), compressor=Blosc(clevel=1), ) - pheno_data["common_variants"] = self.zarr_root[pheno] + pheno_data["common_variants"] = self.common_variants_zroot[pheno] elif temp_dir is not None: tensor_path = ( Path(self.input_tensor_dir.name) / pheno / "input_tensor.zarr"