Commit

Write chunks
endast committed Apr 26, 2024
1 parent 664ac45 commit 88d6dc8
Showing 1 changed file with 26 additions and 13 deletions.
39 changes: 26 additions & 13 deletions deeprvat/deeprvat/associate.py
@@ -275,14 +275,15 @@ def compute_burdens_(
         chunk_start = chunk * chunk_length
         chunk_end = min(n_total_samples, chunk_start + chunk_length)
         samples = range(chunk_start, chunk_end)
-        n_samples = len(samples)
+        n_samples_chunk = len(samples)
         ds = Subset(ds, samples)

         logger.info(f"Processing samples in {samples} from {n_total_samples} in total")
     else:
-        n_samples = n_total_samples
+        logger.info(f"Processing all samples as one chunk.")
+        n_samples_chunk = n_total_samples
         chunk_start = 0
-        chunk_end = n_samples
+        chunk_end = n_samples_chunk

     dataloader_config = data_config["dataloader_config"]
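
Note: this hunk renames n_samples to n_samples_chunk so the per-chunk sample count can no longer be confused with the global total. A minimal sketch of the slicing arithmetic above, assuming chunk_length = ceil(n_total_samples / n_chunks) (that computation sits outside this hunk):

    import math

    n_total_samples, n_chunks = 10_000, 3
    chunk_length = math.ceil(n_total_samples / n_chunks)  # 3334

    for chunk in range(n_chunks):
        chunk_start = chunk * chunk_length
        chunk_end = min(n_total_samples, chunk_start + chunk_length)
        n_samples_chunk = len(range(chunk_start, chunk_end))
        # chunk 0 -> [0, 3334), chunk 1 -> [3334, 6668), chunk 2 -> [6668, 10000)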

@@ -300,11 +301,12 @@ def compute_burdens_(
     burdens_chunk_path = Path(cache_dir) / "chunks" / f"chunk_{chunk}"
     burdens_chunk_path.mkdir(exist_ok=True, parents=True)
-    logger.info(f"Writing chunks to {burdens_chunk_path}")
+    logger.info(f"Writing chunk to {burdens_chunk_path}")

     for i, batch in tqdm(
         enumerate(dl),
         file=sys.stdout,
-        total=(n_samples // batch_size + (n_samples % batch_size != 0)),
+        total=(n_samples_chunk // batch_size + (n_samples_chunk % batch_size != 0)),
     ):

         this_burdens, this_y, this_x, this_sampleid = get_burden(
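
Note: the tqdm total is a ceil-division idiom; n // b + (n % b != 0) equals ceil(n / b), since the boolean contributes 1 exactly when a partial final batch remains. A quick check with illustrative sizes:

    import math

    for n, b in [(3332, 64), (3334, 64), (128, 64)]:
        # True/False coerce to 1/0, so the idiom matches math.ceil
        assert n // b + (n % b != 0) == math.ceil(n / b)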
@@ -313,18 +315,20 @@ def compute_burdens_(

         if i == 0:
             if not skip_burdens:
-                chunk_burden = np.zeros(shape=(n_samples,) + this_burdens.shape[1:])
-            chunk_y = np.zeros(shape=(n_samples,) + this_y.shape[1:])
-            chunk_x = np.zeros(shape=(n_samples,) + this_x.shape[1:])
-            chunk_sampleid = np.zeros(shape=(n_samples))
+                chunk_burden = np.zeros(
+                    shape=(n_samples_chunk,) + this_burdens.shape[1:]
+                )
+            chunk_y = np.zeros(shape=(n_samples_chunk,) + this_y.shape[1:])
+            chunk_x = np.zeros(shape=(n_samples_chunk,) + this_x.shape[1:])
+            chunk_sampleid = np.zeros(shape=(n_samples_chunk))

             logger.info(f"Batch size: {batch['rare_variant_annotations'].shape}")

             if not skip_burdens:
                 burdens = zarr.open(
                     burdens_chunk_path / f"burdens.zarr",
                     mode="a",
-                    shape=this_burdens.shape,
+                    shape=chunk_burden.shape,
                     chunks=(1000, 1000, 1),
                     dtype=np.float32,
                     compressor=Blosc(clevel=compression_level),
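
Note: the zarr store is now sized to the whole chunk (chunk_burden.shape) instead of a single batch (this_burdens.shape), so every batch writes into its own slice of a preallocated array. A hedged sketch of that pattern with made-up shapes, assuming Blosc comes from numcodecs:

    import numpy as np
    import zarr
    from numcodecs import Blosc

    n_samples_chunk, n_genes, batch_size = 256, 100, 64  # illustrative only
    burdens = zarr.open(
        "burdens.zarr",
        mode="a",
        shape=(n_samples_chunk, n_genes, 1),
        chunks=(1000, 1000, 1),
        dtype=np.float32,
        compressor=Blosc(clevel=1),
    )
    for i in range(n_samples_chunk // batch_size):
        start, end = i * batch_size, (i + 1) * batch_size
        # each batch lands in its own slice; no resizing or appending needed
        burdens[start:end] = np.random.rand(batch_size, n_genes, 1)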
@@ -336,23 +340,23 @@ def compute_burdens_(
             y = zarr.open(
                 burdens_chunk_path / f"y.zarr",
                 mode="a",
-                shape=this_y.shape,
+                shape=chunk_y.shape,
                 chunks=(None, None),
                 dtype=np.float32,
                 compressor=Blosc(clevel=compression_level),
             )
             x = zarr.open(
                 burdens_chunk_path / f"x.zarr",
                 mode="a",
-                shape=this_x.shape,
+                shape=chunk_x.shape,
                 chunks=(None, None),
                 dtype=np.float32,
                 compressor=Blosc(clevel=compression_level),
             )
             sample_ids = zarr.open(
                 burdens_chunk_path / f"sample_ids.zarr",
                 mode="a",
-                shape=len(this_sampleid),
+                shape=(n_samples_chunk),
                 chunks=(None),
                 dtype=np.float32,
                 compressor=Blosc(clevel=compression_level),
@@ -377,7 +381,7 @@ def compute_burdens_(
                 break

     if not skip_burdens:
-        burdens[chunk_start:chunk_end] = chunk_burden
+        burdens = chunk_burden

     y = chunk_y
     x = chunk_x
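
Note: because each chunks/chunk_{chunk} directory now holds complete chunk-local arrays (burdens.zarr, x.zarr, y.zarr, sample_ids.zarr), the final assignment writes at chunk-local indices rather than into a global [chunk_start:chunk_end] slice. How the chunks are merged downstream is not part of this diff; a hypothetical read-back sketch (concat_chunks is not deeprvat API):

    from pathlib import Path

    import numpy as np
    import zarr

    def concat_chunks(cache_dir, n_chunks, name="burdens.zarr"):
        # hypothetical helper: stitch per-chunk stores back into one matrix
        parts = [
            zarr.open(str(Path(cache_dir) / "chunks" / f"chunk_{c}" / name), mode="r")[:]
            for c in range(n_chunks)
        ]
        return np.concatenate(parts, axis=0)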
@@ -952,8 +956,17 @@ def compute_burdens(
         skip_burdens=(link_burdens is not None),
     )

+    burden_files = ["burdens.zarr", "x.zarr", "y.zarr", "sample_ids.zarr"]
+
+    for burden_file in burden_files:
+        assert (Path(out_dir) / f"chunk_{chunk}" / burden_file).exists()
+    else:
+        logger.info(f"All zarr files exists for chunk {chunk}")
+
     logger.info("Saving computed burdens, corresponding genes, and targets")
     np.save(Path(out_dir) / "genes.npy", genes)

+    # TODO Remove this...
     if link_burdens is not None:
         source_path = Path(out_dir) / "burdens.zarr"
         source_path.unlink(missing_ok=True)
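
Note: the added validation relies on Python's for/else. With no break in the loop, the else body runs exactly when the loop completes normally, i.e. when every assert passed; a failing assert raises AssertionError and skips it. A tiny refresher:

    for f in ["a.zarr", "b.zarr"]:
        assert f.endswith(".zarr")
    else:
        # reached only on normal loop completion (no break, no exception)
        print("all checks passed")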
