Commit

Write to separate chunks

endast committed Apr 24, 2024
1 parent d9ede28 commit 96af15c
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions deeprvat/deeprvat/associate.py
@@ -298,14 +298,21 @@ def compute_burdens_(
     logger.info("Computing burden scores")
     batch_size = data_config["dataloader_config"]["batch_size"]
     with torch.no_grad():
+        burdens_chunk_path = Path(cache_dir) / "chunks" / f"chunk_{chunk}"
+        burdens_chunk_path.mkdir(exist_ok=True)
+        logger.info(f"Writing chunks to {burdens_chunk_path}")
+
         for i, batch in tqdm(
-            enumerate(dl),
-            file=sys.stdout,
-            total=(n_samples // batch_size + (n_samples % batch_size != 0)),
+                enumerate(dl),
+                file=sys.stdout,
+                total=(n_samples // batch_size + (n_samples % batch_size != 0)),
         ):
+
+
             this_burdens, this_y, this_x, this_sampleid = get_burden(
                 batch, agg_models, device=device, skip_burdens=skip_burdens
             )
+
             if i == 0:
                 if not skip_burdens:
                     chunk_burden = np.zeros(shape=(n_samples,) + this_burdens.shape[1:])
@@ -317,7 +324,7 @@ def compute_burdens_(

                 if not skip_burdens:
                     burdens = zarr.open(
-                        Path(cache_dir) / "burdens.zarr",
+                        burdens_chunk_path / f"burdens.zarr",
                         mode="a",
                         shape=(n_total_samples,) + this_burdens.shape[1:],
                         chunks=(1000, 1000, 1),
@@ -329,29 +336,30 @@ def compute_burdens_(
                     burdens = None

                 y = zarr.open(
-                    Path(cache_dir) / "y.zarr",
+                    burdens_chunk_path / f"y.zarr",
                     mode="a",
                     shape=(n_total_samples,) + this_y.shape[1:],
                     chunks=(None, None),
                     dtype=np.float32,
                     compressor=Blosc(clevel=compression_level),
                 )
                 x = zarr.open(
-                    Path(cache_dir) / "x.zarr",
+                    burdens_chunk_path / f"x.zarr",
                     mode="a",
                     shape=(n_total_samples,) + this_x.shape[1:],
                     chunks=(None, None),
                     dtype=np.float32,
                     compressor=Blosc(clevel=compression_level),
                 )
                 sample_ids = zarr.open(
-                    Path(cache_dir) / "sample_ids.zarr",
+                    burdens_chunk_path / f"sample_ids.zarr",
                     mode="a",
                     shape=(n_total_samples),
                     chunks=(None),
                     dtype=np.float32,
                     compressor=Blosc(clevel=compression_level),
                 )
+
             start_idx = i * batch_size
             end_idx = min(start_idx + batch_size, chunk_end)  # read from chunk shape
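The pattern this commit introduces: each chunk job now writes its y, x, sample_ids, and (optionally) burdens arrays into its own directory under chunks/chunk_{chunk}, rather than all jobs appending to shared zarr stores directly under cache_dir. Below is a minimal, self-contained sketch of that per-chunk write pattern, assuming only that zarr and numcodecs are installed. The cache directory name, shapes, batch size, and compression level are illustrative stand-ins, not values taken from deeprvat.

# Minimal sketch of the per-chunk zarr write pattern above.
# Assumptions (not from the repo): cache_dir, chunk, n_total_samples,
# batch_size, and clevel are illustrative stand-ins.
from pathlib import Path

import numpy as np
import zarr
from numcodecs import Blosc

cache_dir = Path("example_cache")  # hypothetical cache directory
chunk = 0                          # index of the chunk this job handles
n_total_samples = 1000             # illustrative total sample count
batch_size = 256

# One directory per chunk, so concurrent jobs write to disjoint stores.
# parents=True also creates the intermediate "chunks" directory on a fresh
# cache (the commit itself passes only exist_ok=True, so it relies on the
# parent directories already existing).
burdens_chunk_path = cache_dir / "chunks" / f"chunk_{chunk}"
burdens_chunk_path.mkdir(parents=True, exist_ok=True)

# mode="a" creates the store on the first batch and reopens it afterwards.
y = zarr.open(
    burdens_chunk_path / "y.zarr",
    mode="a",
    shape=(n_total_samples, 1),
    chunks=(None, None),  # None entries default to the full dimension size
    dtype=np.float32,
    compressor=Blosc(clevel=1),
)

# Each batch writes its own slice, mirroring start_idx/end_idx in the loop.
for i in range(2):
    start_idx = i * batch_size
    end_idx = min(start_idx + batch_size, n_total_samples)
    y[start_idx:end_idx] = np.zeros((end_idx - start_idx, 1), dtype=np.float32)

Each chunks/chunk_{i} directory can then be treated as an independent store; how the per-chunk stores are combined downstream is outside the scope of this diff.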
