diff --git a/deeprvat/deeprvat/associate.py b/deeprvat/deeprvat/associate.py index 5b744096..d968f6fc 100644 --- a/deeprvat/deeprvat/associate.py +++ b/deeprvat/deeprvat/associate.py @@ -1060,6 +1060,136 @@ def compute_burdens( source_path.symlink_to(link_burdens) + +@cli.command() +@click.option("--n-chunks", type=int, required=True) +@click.option("--skip-burdens", is_flag=True, default=False) +@click.option("--overwrite", is_flag=True, default=False) +@click.argument("burdens-chunks-dir", type=click.Path(exists=True)) +@click.argument("result-dir", type=click.Path(exists=True)) +def combine_burden_chunks( + n_chunks: int, + skip_burdens: bool, + overwrite: bool, + burdens_chunks_dir: Path, + result_dir: Path, +): + combine_burden_chunks_( + n_chunks=n_chunks, + skip_burdens=skip_burdens, + overwrite=overwrite, + burdens_chunks_dir=Path(burdens_chunks_dir), + result_dir=Path(result_dir), + ) + + +def combine_burden_chunks_( + n_chunks: int, + burdens_chunks_dir: Path, + skip_burdens: bool, + overwrite: bool, + result_dir: Path, +): + compression_level = 1 + burdens_chunks_dir = Path(burdens_chunks_dir) + + burdens, x, y, sample_ids = None, None, None, None + start_id = None + end_id = 0 + + for i, chunk in tqdm( + enumerate(range(0, n_chunks)), desc=f"Merging {n_chunks} chunks" + ): + chunk_dir = burdens_chunks_dir / f"chunk_{chunk}" + + if not skip_burdens: + burdens_chunk = zarr.open((chunk_dir / "burdens.zarr").as_posix(), mode="r") + assert burdens_chunk.attrs["chunk"] == chunk + + y_chunk = zarr.open((chunk_dir / "y.zarr").as_posix(), mode="r") + + x_chunk = zarr.open((chunk_dir / "x.zarr").as_posix(), mode="r") + sample_ids_chunk = zarr.open( + (chunk_dir / "sample_ids.zarr").as_posix(), mode="r" + ) + + total_samples = sample_ids_chunk.attrs["n_total_samples"] + + assert y_chunk.attrs["chunk"] == chunk + assert x_chunk.attrs["chunk"] == chunk + assert sample_ids_chunk.attrs["chunk"] == chunk + + burdens_path = result_dir / "burdens.zarr" + x_path = result_dir / "x.zarr" + y_path = result_dir / "y.zarr" + sample_ids_path = result_dir / "sample_ids.zarr" + + if i == 0: + if not skip_burdens: + burdens_shape = (total_samples,) + burdens_chunk.shape[1:] + + if not overwrite: + assert not burdens_path.exists() + else: + logger.debug("Overwriting existing files") + + logger.debug(f"Opening {burdens_path} in append mode") + burdens = zarr.open( + burdens_path.as_posix(), + mode="a", + shape=burdens_shape, + chunks=(1000, 1000, 1), + dtype=np.float32, + compressor=Blosc(clevel=compression_level), + ) + assert burdens_path.exists() + + logger.debug(f"Opening {y_path} in append mode") + y = zarr.open( + y_path, + mode="a", + shape=(total_samples,) + y_chunk.shape[1:], + chunks=(None, None), + dtype=np.float32, + compressor=Blosc(clevel=compression_level), + ) + logger.debug(f"Opening {x_path} in append mode") + x = zarr.open( + x_path, + mode="a", + shape=(total_samples,) + x_chunk.shape[1:], + chunks=(None, None), + dtype=np.float32, + compressor=Blosc(clevel=compression_level), + ) + logger.debug(f"Opening {sample_ids_path} in append mode") + sample_ids = zarr.open( + sample_ids_path, + mode="a", + shape=(total_samples), + chunks=(None), + dtype=np.float32, + compressor=Blosc(clevel=compression_level), + ) + + assert x_path.exists() + assert y_path.exists() + assert sample_ids_path.exists() + + start_id = end_id + end_id += len(sample_ids_chunk) + + y[start_id:end_id] = y_chunk[:] + x[start_id:end_id] = x_chunk[:] + sample_ids[start_id:end_id] = sample_ids_chunk[:] + + if not skip_burdens: + burdens[start_id:end_id] = burdens_chunk[:] + + logger.info(f"Done merging {n_chunks} chunks.") + + + def regress_on_gene_scoretest( gene: str, burdens: np.ndarray, diff --git a/tests/deeprvat/test_associate.py b/tests/deeprvat/test_associate.py new file mode 100644 index 00000000..aecc6ea8 --- /dev/null +++ b/tests/deeprvat/test_associate.py @@ -0,0 +1,182 @@ +import zipfile +from pathlib import Path + +import numpy as np +import pytest +import zarr + +from deeprvat.deeprvat.associate import combine_burden_chunks_ + +script_dir = Path(__file__).resolve().parent +tests_data_dir = script_dir / "test_data" / "associate" + + +def open_zarr(zarr_path: Path): + zarr_data = zarr.open(zarr_path.as_posix(), mode="r") + return zarr_data + + +def unzip_data(zip_path, out_path): + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(out_path) + + return out_path + + +@pytest.fixture +def chunks_data(request, tmp_path) -> Path: + zipped_chunks_path = Path(request.param) + chunks_unpacked_path = tmp_path / "chunks" + unzip_data(zip_path=zipped_chunks_path, out_path=chunks_unpacked_path) + + yield chunks_unpacked_path + + +@pytest.fixture +def expected_array(request, tmp_path) -> Path: + zipped_expected_path = Path(request.param) + expected_data_unpacked_path = tmp_path / "expected" + unzip_data(zip_path=zipped_expected_path, out_path=expected_data_unpacked_path) + + yield expected_data_unpacked_path + + +@pytest.mark.parametrize( + "n_chunks, skip_burdens, overwrite, chunks_data, expected_array", + [ + ( + n_chunks, + False, + False, + tests_data_dir / f"combine_burden_chunks/input/chunks_{n_chunks}.zip", + tests_data_dir / f"combine_burden_chunks/expected/burdens_{n_chunks}.zip", + ) + for n_chunks in range(2, 5) + ] + + [ + ( + n_chunks, + True, + True, + tests_data_dir / f"combine_burden_chunks/input/chunks_{n_chunks}.zip", + tests_data_dir / f"combine_burden_chunks/expected/burdens_{n_chunks}.zip", + ) + for n_chunks in range(2, 5) + ] + + [ + ( + n_chunks, + True, + False, + tests_data_dir / f"combine_burden_chunks/input/chunks_{n_chunks}.zip", + tests_data_dir / f"combine_burden_chunks/expected/burdens_{n_chunks}.zip", + ) + for n_chunks in range(2, 5) + ] + + [ + ( + n_chunks, + False, + True, + tests_data_dir / f"combine_burden_chunks/input/chunks_{n_chunks}.zip", + tests_data_dir / f"combine_burden_chunks/expected/burdens_{n_chunks}.zip", + ) + for n_chunks in range(2, 5) + ], + indirect=["chunks_data", "expected_array"], +) +def test_combine_burden_chunks_data_same( + n_chunks, + skip_burdens, + overwrite, + tmp_path, + chunks_data, + expected_array, +): + + combine_burden_chunks_( + n_chunks=n_chunks, + burdens_chunks_dir=chunks_data, + skip_burdens=skip_burdens, + overwrite=overwrite, + result_dir=tmp_path, + ) + + zarr_files = ["x.zarr", "y.zarr", "sample_ids.zarr", "burdens.zarr"] + if skip_burdens: + zarr_files.remove("burdens.zarr") + + for zarr_file in zarr_files: + + expected_data = open_zarr(zarr_path=(expected_array / zarr_file)) + written_data = open_zarr(zarr_path=(tmp_path / zarr_file)) + expected_data_arr, written_data_arr = expected_data[:], written_data[:] + assert written_data_arr.dtype == expected_data.dtype + assert expected_data_arr.shape == written_data_arr.shape + assert np.array_equal(expected_data_arr, written_data_arr, equal_nan=True) + + # No more than 10% zeros + nr_zeros = np.count_nonzero(written_data_arr == 0) + zero_percentage = nr_zeros / len(written_data_arr) + assert zero_percentage < 0.1 + + +@pytest.mark.parametrize( + "n_chunks, skip_burdens, overwrite, chunks_data", + [ + ( + n_chunks, + False, + False, + tests_data_dir / f"combine_burden_chunks/input/chunks_{n_chunks}.zip", + ) + for n_chunks in range(2, 5) + ] + + [ + ( + n_chunks, + True, + True, + tests_data_dir / f"combine_burden_chunks/input/chunks_{n_chunks}.zip", + ) + for n_chunks in range(2, 5) + ] + + [ + ( + n_chunks, + True, + False, + tests_data_dir / f"combine_burden_chunks/input/chunks_{n_chunks}.zip", + ) + for n_chunks in range(2, 5) + ] + + [ + ( + n_chunks, + False, + True, + tests_data_dir / f"combine_burden_chunks/input/chunks_{n_chunks}.zip", + ) + for n_chunks in range(2, 5) + ], + indirect=["chunks_data"], +) +def test_combine_burden_chunks_file_exists( + n_chunks, skip_burdens, overwrite, tmp_path, chunks_data +): + + combine_burden_chunks_( + n_chunks=n_chunks, + burdens_chunks_dir=chunks_data, + skip_burdens=skip_burdens, + overwrite=overwrite, + result_dir=tmp_path, + ) + + if not skip_burdens: + assert (tmp_path / "burdens.zarr").exists() + else: + assert not (tmp_path / "burdens.zarr").exists() + assert (tmp_path / "x.zarr").exists() + assert (tmp_path / "y.zarr").exists() + assert (tmp_path / "sample_ids.zarr").exists()