diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000..3e133f15 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,13 @@ +# What + +*What does this PR do, and preferably how* + +# Testing +*Testing this PR involves X,Y,Z* + +## Test scenarios +*Test these things* + +1. Instructions +2. ... +3. ... diff --git a/.github/workflows/autoblack_pull_request.yml b/.github/workflows/autoblack_pull_request.yml new file mode 100644 index 00000000..2936e248 --- /dev/null +++ b/.github/workflows/autoblack_pull_request.yml @@ -0,0 +1,35 @@ +# GitHub Action that uses Black to reformat the Python code in an incoming pull request. +# If all Python code in the pull request is complient with Black then this Action does nothing. +# Othewrwise, Black is run and its changes are committed back to the incoming pull request. +# https://github.com/cclauss/autoblack + +name: autoblack_pull_request +on: [ pull_request ] +jobs: + black-code: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + ref: ${{ github.head_ref }} + - uses: actions/setup-python@v4 + with: + python-version: '3.11' + - run: pip install black + - run: black --check . + - name: If needed, commit black changes to the pull request + if: failure() + run: | + printenv | grep GITHUB + git config --global user.name 'PMBio' + git config --global user.email 'PMBio@users.noreply.github.com' + git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/$GITHUB_REPOSITORY + git remote -v + git branch + git status + black . + git status + echo ready to commit + git commit -am "fixup! Format Python code with psf/black pull_request" + echo ready to push + git push diff --git a/.github/workflows/docs-tests.yml b/.github/workflows/docs-tests.yml new file mode 100644 index 00000000..c028c44d --- /dev/null +++ b/.github/workflows/docs-tests.yml @@ -0,0 +1,23 @@ +name: "Pull Request Docs Check" +run-name: "Docs Check 📑📝" + +on: +- pull_request + +jobs: + docs-build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: ammaraskar/sphinx-action@0.4 + with: + docs-folder: "docs/" + + docs-link-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: ammaraskar/sphinx-action@0.4 + with: + docs-folder: "docs/" + build-command: "make linkcheck" diff --git a/.github/workflows/github-actions.yml b/.github/workflows/github-actions.yml index 006baf02..5b3ff8a6 100644 --- a/.github/workflows/github-actions.yml +++ b/.github/workflows/github-actions.yml @@ -8,26 +8,33 @@ jobs: steps: - name: Check out repository code uses: actions/checkout@v3 - - name: Training Association Testing smoke test - uses: snakemake/snakemake-github-action@v1.24.0 + - uses: mamba-org/setup-micromamba@v1.4.3 with: - directory: 'example' - snakefile: 'pipelines/training_association_testing.snakefile' - args: '-j 2 -n' + environment-name: deeprvat-gh-action + environment-file: ${{ github.workspace }}/deeprvat_env_no_gpu.yml + cache-environment: true + cache-downloads: true + - name: Smoketest training_association_testing pipeline + run: | + python -m snakemake -n -j 2 --directory ${{ github.workspace }}/example \ + --snakefile ${{ github.workspace }}/pipelines/training_association_testing.snakefile --show-failed-logs + shell: micromamba-shell {0} - name: Link pretrained models run: cd ${{ github.workspace }}/example && ln -s ../pretrained_models - - name: Association Testing Pretrained Smoke Test - uses: 
snakemake/snakemake-github-action@v1.24.0 - with: - directory: 'example' - snakefile: 'pipelines/association_testing_pretrained.snakefile' - args: '-j 2 -n' - - name: Seed Gene Discovery Smoke Test - uses: snakemake/snakemake-github-action@v1.24.0 - with: - directory: 'example' - snakefile: 'pipelines/seed_gene_discovery.snakefile' - args: '-j 2 -n' + shell: bash -el {0} + - name: Smoketest association_testing_pretrained pipeline + run: | + python -m snakemake -n -j 2 --directory ${{ github.workspace }}/example \ + --snakefile ${{ github.workspace }}/pipelines/association_testing_pretrained.snakefile --show-failed-logs + shell: micromamba-shell {0} + - name: Copy seed gene discovery snakemake config + run: cd ${{ github.workspace }}/example && cp ../deeprvat/seed_gene_discovery/config.yaml . + shell: bash -el {0} + - name: Smoketest seed_gene_discovery pipeline + run: | + python -m snakemake -n -j 2 --directory ${{ github.workspace }}/example \ + --snakefile ${{ github.workspace }}/pipelines/seed_gene_discovery.snakefile --show-failed-logs + shell: micromamba-shell {0} DeepRVAT-Pipeline-Tests: runs-on: ubuntu-latest @@ -76,15 +83,94 @@ jobs: steps: - name: Check out repository code uses: actions/checkout@v3 - - name: Preprocessing Smoke Test - uses: snakemake/snakemake-github-action@v1.24.0 + - uses: mamba-org/setup-micromamba@v1.4.3 + with: + environment-name: deeprvat-preprocess-gh-action + environment-file: ${{ github.workspace }}/deeprvat_preprocessing_env.yml + cache-environment: true + cache-downloads: true + + - name: Fake fasta data + if: steps.cache-fasta.outputs.cache-hit != 'true' + run: | + cd ${{ github.workspace }}/example/preprocess && touch workdir/reference/GRCh38.primary_assembly.genome.fa + + - name: Run preprocessing pipeline no qc Smoke Test + run: | + python -m snakemake -n -j 2 --directory ${{ github.workspace }}/example/preprocess \ + --snakefile ${{ github.workspace }}/pipelines/preprocess_no_qc.snakefile \ + --configfile ${{ github.workspace }}/pipelines/config/deeprvat_preprocess_config.yaml --show-failed-logs + shell: micromamba-shell {0} + + + - name: Preprocessing pipeline with qc Smoke Test + run: | + python -m snakemake -n -j 2 --directory ${{ github.workspace }}/example/preprocess \ + --snakefile ${{ github.workspace }}/pipelines/preprocess_with_qc.snakefile \ + --configfile ${{ github.workspace }}/pipelines/config/deeprvat_preprocess_config.yaml --show-failed-logs + shell: micromamba-shell {0} + + + DeepRVAT-Annotation-Pipeline-Smoke-Tests: + runs-on: ubuntu-latest + steps: + - name: Check out repository code + uses: actions/checkout@v3 + - uses: mamba-org/setup-micromamba@v1.4.3 + with: + environment-name: deeprvat-preprocess-gh-action + environment-file: ${{ github.workspace }}/deeprvat_preprocessing_env.yml + cache-environment: true + cache-downloads: true + - name: Annotations Smoke Test + run: | + python -m snakemake -n -j 2 --directory ${{ github.workspace }}/example/annotations \ + --snakefile ${{ github.workspace }}/pipelines/annotations.snakefile \ + --configfile ${{ github.workspace }}/pipelines/config/deeprvat_annotation_config.yaml --show-failed-logs + shell: micromamba-shell {0} + + + DeepRVAT-Preprocessing-Pipeline-Tests-No-QC: + runs-on: ubuntu-latest + needs: DeepRVAT-Preprocessing-Pipeline-Smoke-Tests + steps: + - name: Check out repository code + uses: actions/checkout@v3 + - uses: mamba-org/setup-micromamba@v1.4.3 with: - directory: 'example/preprocess' - snakefile: 'pipelines/preprocess.snakefile' - args: '-j 2 -n --configfile 
pipelines/config/deeprvat_preprocess_config.yaml' - stagein: 'touch example/preprocess/workdir/reference/GRCh38.primary_assembly.genome.fa' + environment-name: deeprvat-preprocess-gh-action + environment-file: ${{ github.workspace }}/deeprvat_preprocessing_env.yml + cache-environment: true + cache-downloads: true + + - name: Install DeepRVAT + run: pip install -e ${{ github.workspace }} + shell: micromamba-shell {0} + + - name: Cache Fasta file + id: cache-fasta + uses: actions/cache@v3 + with: + path: example/preprocess/workdir/reference + key: ${{ runner.os }}-reference-fasta + + - name: Download and unpack fasta data + if: steps.cache-fasta.outputs.cache-hit != 'true' + run: | + cd ${{ github.workspace }}/example/preprocess && \ + wget https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_44/GRCh38.primary_assembly.genome.fa.gz \ + -O workdir/reference/GRCh38.primary_assembly.genome.fa.gz \ + && gzip -d workdir/reference/GRCh38.primary_assembly.genome.fa.gz + + - name: Run preprocessing pipeline + run: | + python -m snakemake -j 2 --directory ${{ github.workspace }}/example/preprocess \ + --snakefile ${{ github.workspace }}/pipelines/preprocess_no_qc.snakefile \ + --configfile ${{ github.workspace }}/pipelines/config/deeprvat_preprocess_config.yaml --show-failed-logs + shell: micromamba-shell {0} + - DeepRVAT-Preprocessing-Pipeline-Tests: + DeepRVAT-Preprocessing-Pipeline-Tests-With-QC: runs-on: ubuntu-latest needs: DeepRVAT-Preprocessing-Pipeline-Smoke-Tests steps: @@ -120,6 +206,6 @@ jobs: - name: Run preprocessing pipeline run: | python -m snakemake -j 2 --directory ${{ github.workspace }}/example/preprocess \ - --snakefile ${{ github.workspace }}/pipelines/preprocess.snakefile \ + --snakefile ${{ github.workspace }}/pipelines/preprocess_with_qc.snakefile \ --configfile ${{ github.workspace }}/pipelines/config/deeprvat_preprocess_config.yaml --show-failed-logs shell: micromamba-shell {0} diff --git a/.github/workflows/test-runner.yml b/.github/workflows/test-runner.yml index e99ad195..4decca6b 100644 --- a/.github/workflows/test-runner.yml +++ b/.github/workflows/test-runner.yml @@ -4,6 +4,25 @@ on: [ push ] jobs: DeepRVAT-Tests-Runner: + runs-on: ubuntu-latest + steps: + - name: Check out repository code + uses: actions/checkout@v3 + - uses: mamba-org/setup-micromamba@v1.4.3 + with: + environment-name: deeprvat-preprocess-gh-action + environment-file: ${{ github.workspace }}/deeprvat_env_no_gpu.yml + cache-environment: true + cache-downloads: true + + - name: Install DeepRVAT + run: pip install -e ${{ github.workspace }} + shell: micromamba-shell {0} + - name: Run pytest deeprvat + run: pytest -v ${{ github.workspace }}/tests/deeprvat + shell: micromamba-shell {0} + + DeepRVAT-Tests-Runner-Preprocessing: runs-on: ubuntu-latest steps: @@ -20,6 +39,6 @@ jobs: run: pip install -e ${{ github.workspace }} shell: micromamba-shell {0} - - name: Run pytest - run: pytest -v ${{ github.workspace }}/tests + - name: Run pytest preprocessing + run: pytest -v ${{ github.workspace }}/tests/preprocessing shell: micromamba-shell {0} diff --git a/.gitignore b/.gitignore index f9347202..c512620e 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,4 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
.idea/ +/docs/apidocs/ diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..f9221ef3 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,20 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the OS, Python version and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.12" + +sphinx: + configuration: docs/conf.py + fail_on_warning: true + +python: + install: + - requirements: docs/requirements.txt diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 00000000..2ff1a871 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,68 @@ +cff-version: 1.2.0 +title: DeepRVAT +message: >- + If you use this software, please cite it using the + metadata from this file. +type: software +authors: + - given-names: Brian + family-names: Clarke + orcid: 'https://orcid.org/0000-0002-6695-286X' + - given-names: Eva + family-names: Holtkamp + orcid: 'https://orcid.org/0000-0002-2129-9908' + - given-names: Hakime + family-names: Öztürk + - given-names: Marcel + family-names: Mück + orcid: 'https://orcid.org/0009-0000-3129-2630' + - given-names: Magnus + family-names: Wahlberg + orcid: 'https://orcid.org/0009-0001-9140-2392' + - given-names: Kayla + family-names: Meyer + orcid: 'https://orcid.org/0009-0003-5063-5266' + - given-names: Felix + family-names: Munzlinger + orcid: 'https://orcid.org/0009-0005-1407-8145' + - given-names: Felix + family-names: Brechtmann + orcid: 'https://orcid.org/0000-0002-0110-152X' + - given-names: Florian Rupert + family-names: Hölzlwimmer + orcid: 'https://orcid.org/0000-0002-5522-2562' + - given-names: Julien + family-names: Gagneur + orcid: 'https://orcid.org/0000-0002-8924-8365' + - given-names: Oliver + family-names: Stegle + orcid: 'https://orcid.org/0000-0002-8818-7193' +identifiers: + - type: doi + value: 10.1101/2023.07.12.548506 +repository-code: 'https://github.com/PMBio/deeprvat' +abstract: >- + Integration of variant annotations using deep set networks + boosts rare variant association genetics. + Rare genetic variants can strongly predispose to disease, + yet accounting for rare variants in genetic analyses is + statistically challenging. While rich variant annotations + hold the promise to enable well-powered rare variant + association tests, methods integrating variant annotations + in a data-driven manner are lacking. Here, we propose + DeepRVAT, a set neural network-based approach to learn + burden scores from rare variants, annotations and + phenotypes. In contrast to existing methods, DeepRVAT + yields a single, trait-agnostic, nonlinear gene impairment + score, enabling both risk prediction and gene discovery in + a unified framework. On 21 quantitative traits and + whole-exome-sequencing data from UK Biobank, DeepRVAT + offers substantial increases in gene discoveries and + improved replication rates in held-out data. Moreover, we + demonstrate that the integrative DeepRVAT gene impairment + score greatly improves detection of individuals at high + genetic risk. We show that pre-trained DeepRVAT scores + generalize across traits, opening up the possibility to + conduct highly computationally efficient rare variant + tests. 
+license: MIT diff --git a/README.md b/README.md index 7a131954..6ee68736 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ Rare variant association testing using deep learning and data-driven burden scores +[![Documentation Status](https://readthedocs.org/projects/deeprvat/badge/?version=latest)](https://deeprvat.readthedocs.io/en/latest/?badge=latest) ## Installation @@ -10,7 +11,9 @@ Rare variant association testing using deep learning and data-driven burden scor git clone git@github.com:PMBio/deeprvat.git ``` 1. Change directory to the repository: `cd deeprvat` -1. Install the conda environment. We recommend using `mamba`, though you may also replace `mamba` with `conda`: +1. Install the conda environment. We recommend using [mamba](https://mamba.readthedocs.io/en/latest/index.html), though you may also replace `mamba` with `conda` + + *note: [the current deeprvat env does not support cuda when installed with conda](https://github.com/PMBio/deeprvat/issues/16), install using mamba for cuda support.* ``` mamba env create -n deeprvat -f deeprvat_env.yaml ``` @@ -34,14 +37,13 @@ If you are running on an computing cluster, you will need a [profile](https://gi ### Run the preprocessing pipeline on VCF files -Instructions [here](https://github.com/PMBio/deeprvat/blob/main/deeprvat/preprocessing/README.md) +Instructions [here](https://deeprvat.readthedocs.io/en/latest/preprocessing.html) ### Annotate variants -Instructions [here](https://github.com/PMBio/deeprvat/blob/main/deeprvat/annotations/README.md) +Instructions [here](https://deeprvat.readthedocs.io/en/latest/annotations.html) -**NOTE:** The annotation pipeline does not yet provide full output as required by DeepRVAT, but will be continually updated to be more complete. ### Try the full training and association testing pipeline on some example data diff --git a/deeprvat/annotations/__init__.py b/deeprvat/annotations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/deeprvat/data/dense_gt.py b/deeprvat/data/dense_gt.py index 16670651..c02521d5 100644 --- a/deeprvat/data/dense_gt.py +++ b/deeprvat/data/dense_gt.py @@ -409,7 +409,18 @@ def transform_data(self): self.phenotype_df[col] = rng.permutation( self.phenotype_df[col].to_numpy() ) - + if len(self.y_phenotypes) > 0: + unique_y_val = self.phenotype_df[self.y_phenotypes[0]].unique() + n_unique_y_val = np.count_nonzero(~np.isnan(unique_y_val)) + logger.info(f"unique y values {unique_y_val}") + logger.info(n_unique_y_val) + else: + n_unique_y_val = 0 + if n_unique_y_val == 2: + logger.warning( + "Not applying y transformation because y only has two values and seems to be binary" + ) + self.y_transformation = None if self.y_transformation is not None: if self.y_transformation == "standardize": logger.debug(" Standardizing target phenotype") @@ -425,6 +436,8 @@ def transform_data(self): ) else: raise ValueError(f"Unknown y_transformation: {self.y_transformation}") + else: + logger.warning("Not transforming phenotype") def setup_annotations( self, diff --git a/deeprvat/deeprvat/__init__.py b/deeprvat/deeprvat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/deeprvat/deeprvat/associate.py b/deeprvat/deeprvat/associate.py index 8e5efa42..f24d3fa3 100644 --- a/deeprvat/deeprvat/associate.py +++ b/deeprvat/deeprvat/associate.py @@ -50,7 +50,25 @@ def get_burden( agg_models: Dict[str, List[nn.Module]], device: torch.device = torch.device("cpu"), skip_burdens=False, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute burden scores for rare variants. + + :param batch: A dictionary containing batched data from the DataLoader. + :type batch: Dict + :param agg_models: Loaded PyTorch model(s) for each repeat used for burden computation. + Each key in the dictionary corresponds to a respective repeat. + :type agg_models: Dict[str, List[nn.Module]] + :param device: Device to perform computations on, defaults to "cpu". + :type device: torch.device + :param skip_burdens: Flag to skip burden computation, defaults to False. + :type skip_burdens: bool + :return: Tuple containing burden scores, target y phenotype values, x phenotypes and sample ids. + :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + + .. note:: + Checkpoint models all corresponding to the same repeat are averaged for that repeat. + """ with torch.no_grad(): X = batch["rare_variant_annotations"].to(device) burden = [] @@ -69,11 +87,20 @@ def get_burden( y = batch["y"] x = batch["x_phenotypes"] + sample_ids = batch["sample"] - return burden, y, x + return burden, y, x, sample_ids def separate_parallel_results(results: List) -> Tuple[List, ...]: + """ + Separate results from running regression on each gene. + + :param results: List of results obtained from regression analysis. + :type results: List + :return: Tuple of lists containing separated results of regressed_genes, betas, and pvals. + :rtype: Tuple[List, ...] + """ return tuple(map(list, zip(*results))) @@ -88,6 +115,20 @@ def make_dataset_( data_key="data", samples: Optional[List[int]] = None, ) -> Dataset: + """ + Create a dataset based on the configuration. + + :param config: Configuration dictionary. + :type config: Dict + :param debug: Flag for debugging, defaults to False. + :type debug: bool + :param data_key: Key for dataset configuration in the config dictionary, defaults to "data". + :type data_key: str + :param samples: List of sample indices to include in the dataset, defaults to None. + :type samples: List[int] + :return: Loaded instance of the created dataset. + :rtype: Dataset + """ data_config = config[data_key] ds_pickled = data_config.get("pickled", None) @@ -96,12 +137,9 @@ def make_dataset_( with open(ds_pickled, "rb") as f: ds = pickle.load(f) else: - variant_file = data_config.get( - "variant_file", f'{data_config["gt_file"][:-3]}_variants.parquet' - ) ds = DenseGTDataset( data_config["gt_file"], - variant_file=variant_file, + variant_file=data_config["variant_file"], split="", skip_y_na=False, **copy.deepcopy(data_config["dataset_config"]), @@ -125,6 +163,19 @@ def make_dataset_( @click.argument("config-file", type=click.Path(exists=True)) @click.argument("out-file", type=click.Path()) def make_dataset(debug: bool, data_key: str, config_file: str, out_file: str): + """ + Create a dataset based on the provided configuration and save to a pickle file. + + :param debug: Flag for debugging. + :type debug: bool + :param data_key: Key for dataset configuration in the config dictionary, defaults to "data". + :type data_key: str + :param config_file: Path to the configuration file. + :type config_file: str + :param out_file: Path to the output file. 
+ :type out_file: str + :return: Created dataset saved to out_file.pkl + """ with open(config_file) as f: config = yaml.safe_load(f) @@ -146,7 +197,41 @@ def compute_burdens_( bottleneck: bool = False, compression_level: int = 1, skip_burdens: bool = False, -) -> Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array]: +) -> Tuple[ + np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array, zarr.core.Array +]: + """ + Compute burdens using the PyTorch model for each repeat. + + :param debug: Flag for debugging. + :type debug: bool + :param config: Configuration dictionary. + :type config: Dict + :param ds: Torch dataset. + :type ds: torch.utils.data.Dataset + :param cache_dir: Directory to cache zarr files of computed burdens, x phenotypes, and y phenotypes. + :type cache_dir: str + :param agg_models: Loaded PyTorch model(s) for each repeat used for burden computation. + Each key in the dictionary corresponds to a respective repeat. + :type agg_models: Dict[str, List[nn.Module]] + :param n_chunks: Number of chunks to split data for processing, defaults to None. + :type n_chunks: Optional[int] + :param chunk: Index of the chunk of data, defaults to None. + :type chunk: Optional[int] + :param device: Device to perform computations on, defaults to "cpu". + :type device: torch.device + :param bottleneck: Flag to enable bottlenecking number of batches, defaults to False. + :type bottleneck: bool + :param compression_level: Blosc compressor compression level for zarr files, defaults to 1. + :type compression_level: int + :param skip_burdens: Flag to skip burden computation, defaults to False. + :type skip_burdens: bool + :return: Tuple containing genes, burdens, target y phenotypes, x phenotypes and sample ids. + :rtype: Tuple[np.ndarray, zarr.core.Array, zarr.core.Array, zarr.core.Array, zarr.core.Array] + + .. note:: + Checkpoint models all corresponding to the same repeat are averaged for that repeat. 
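# --- Editor's illustrative sketch (not part of this diff) ---
# Minimal demonstration of the note above: checkpoints belonging to the same
# repeat are averaged, giving one burden column per repeat. The toy modules,
# tensor sizes, and the variable name `X` are assumptions for illustration only;
# see get_burden() for the actual batch handling.
import torch
import torch.nn as nn

agg_models = {
    "repeat_0": [nn.Linear(4, 1), nn.Linear(4, 1)],  # two checkpoints, one repeat
    "repeat_1": [nn.Linear(4, 1)],                   # a single checkpoint
}
X = torch.randn(8, 4)  # stands in for batch["rare_variant_annotations"]

with torch.no_grad():
    per_repeat = [
        torch.mean(torch.stack([model(X) for model in models]), dim=0)
        for models in agg_models.values()
    ]
    burdens = torch.cat(per_repeat, dim=-1)  # shape: samples x repeats

print(burdens.shape)  # torch.Size([8, 2])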
+ """ if not skip_burdens: logger.info("agg_models[*][*].reverse:") pprint( @@ -198,7 +283,7 @@ def compute_burdens_( file=sys.stdout, total=(n_samples // batch_size + (n_samples % batch_size != 0)), ): - this_burdens, this_y, this_x = get_burden( + this_burdens, this_y, this_x, this_sampleid = get_burden( batch, agg_models, device=device, skip_burdens=skip_burdens ) if i == 0: @@ -206,6 +291,7 @@ def compute_burdens_( chunk_burden = np.zeros(shape=(n_samples,) + this_burdens.shape[1:]) chunk_y = np.zeros(shape=(n_samples,) + this_y.shape[1:]) chunk_x = np.zeros(shape=(n_samples,) + this_x.shape[1:]) + chunk_sampleid = np.zeros(shape=(n_samples)) logger.info(f"Batch size: {batch['rare_variant_annotations'].shape}") @@ -238,6 +324,14 @@ def compute_burdens_( dtype=np.float32, compressor=Blosc(clevel=compression_level), ) + sample_ids = zarr.open( + Path(cache_dir) / "sample_ids.zarr", + mode="a", + shape=(n_total_samples), + chunks=(None), + dtype=np.float32, + compressor=Blosc(clevel=compression_level), + ) start_idx = i * batch_size end_idx = min(start_idx + batch_size, chunk_end) # read from chunk shape @@ -247,6 +341,7 @@ def compute_burdens_( chunk_y[start_idx:end_idx] = this_y chunk_x[start_idx:end_idx] = this_x + chunk_sampleid[start_idx:end_idx] = this_sampleid if debug: logger.info( @@ -261,13 +356,14 @@ def compute_burdens_( y[chunk_start:chunk_end] = chunk_y x[chunk_start:chunk_end] = chunk_x + sample_ids[chunk_start:chunk_end] = chunk_sampleid if torch.cuda.is_available(): logger.info( "Max GPU memory allocated: " f"{torch.cuda.max_memory_allocated(0)} bytes" ) - return ds_full.rare_embedding.genes, burdens, y, x + return ds_full.rare_embedding.genes, burdens, y, x, sample_ids def load_one_model( @@ -275,6 +371,18 @@ def load_one_model( checkpoint: str, device: torch.device = torch.device("cpu"), ): + """ + Load a single burden score computation model from a checkpoint file. + + :param config: Configuration dictionary. + :type config: Dict + :param checkpoint: Path to the model checkpoint file. + :type checkpoint: str + :param device: Device to load the model onto, defaults to "cpu". + :type device: torch.device + :return: Loaded PyTorch model for burden score computation. + :rtype: nn.Module + """ model_class = getattr(deeprvat_models, config["model"]["type"]) model = model_class.load_from_checkpoint( checkpoint, @@ -293,6 +401,17 @@ def load_one_model( def reverse_models( model_config_file: str, data_config_file: str, checkpoint_files: Tuple[str] ): + """ + Determine if the burden score computation PyTorch model should reverse the output based on PLOF annotations. + + :param model_config_file: Path to the model configuration file. + :type model_config_file: str + :param data_config_file: Path to the data configuration file. + :type data_config_file: str + :param checkpoint_files: Paths to checkpoint files. + :type checkpoint_files: Tuple[str] + :return: checkpoint.reverse file is created if the model should reverse the burden score output. + """ with open(model_config_file) as f: model_config = yaml.safe_load(f) @@ -354,6 +473,25 @@ def load_models( checkpoint_files: Tuple[str], device: torch.device = torch.device("cpu"), ) -> Dict[str, List[nn.Module]]: + """ + Load models from multiple checkpoints for multiple repeats. + + :param config: Configuration dictionary. + :type config: Dict + :param checkpoint_files: Paths to checkpoint files. + :type checkpoint_files: Tuple[str] + :param device: Device to load the models onto, defaults to "cpu". 
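# --- Editor's illustrative sketch (not part of this diff) ---
# The append-mode zarr pattern used above for burdens.zarr, y.zarr, x.zarr and
# the new sample_ids.zarr, reduced to a single toy array. The path and sizes
# are made up; only the open/slice-assign pattern is the point.
import numpy as np
import zarr
from numcodecs import Blosc

n_total_samples, batch_size = 10, 4
sample_ids = zarr.open(
    "sample_ids_demo.zarr",
    mode="a",
    shape=(n_total_samples,),
    chunks=None,
    dtype=np.float32,
    compressor=Blosc(clevel=1),
)
for start_idx in range(0, n_total_samples, batch_size):
    end_idx = min(start_idx + batch_size, n_total_samples)
    sample_ids[start_idx:end_idx] = np.arange(start_idx, end_idx, dtype=np.float32)

print(sample_ids[:])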
+ :type device: torch.device + :return: Dictionary of loaded PyTorch models for burden score computation for each repeat. + :rtype: Dict[str, List[nn.Module]] + + :Examples: + + >>> config = {"model": {"type": "MyModel", "config": {"param": "value"}}} + >>> checkpoint_files = ("checkpoint1.pth", "checkpoint2.pth") + >>> load_models(config, checkpoint_files) + {'repeat_0': [MyModel(), MyModel()]} + """ logger.info("Loading models and checkpoints") if all( @@ -433,6 +571,35 @@ def compute_burdens( checkpoint_files: Tuple[str], out_dir: str, ): + """ + Compute burdens based on the provided model and dataset. + + :param debug: Flag for debugging. + :type debug: bool + :param bottleneck: Flag to enable bottlenecking number of batches. + :type bottleneck: bool + :param n_chunks: Number of chunks to split data for processing, defaults to None. + :type n_chunks: Optional[int] + :param chunk: Index of the chunk of data, defaults to None. + :type chunk: Optional[int] + :param dataset_file: Path to the dataset file, i.e., association_dataset.pkl. + :type dataset_file: Optional[str] + :param link_burdens: Path to burden.zarr file to link. + :type link_burdens: Optional[str] + :param data_config_file: Path to the data configuration file. + :type data_config_file: str + :param model_config_file: Path to the model configuration file. + :type model_config_file: str + :param checkpoint_files: Paths to model checkpoint files. + :type checkpoint_files: Tuple[str] + :param out_dir: Path to the output directory. + :type out_dir: str + :return: Corresonding genes, computed burdens, y phenotypes, x phenotypes and sample ids are saved in the out_dir. + :rtype: [np.ndarray], [zarr.core.Array], [zarr.core.Array], [zarr.core.Array], [zarr.core.Array] + + .. note:: + Checkpoint models all corresponding to the same repeat are averaged for that repeat. + """ if len(checkpoint_files) == 0: raise ValueError("At least one checkpoint file must be supplied") @@ -461,7 +628,7 @@ def compute_burdens( else: agg_models = None - genes, _, _, _ = compute_burdens_( + genes, _, _, _, _ = compute_burdens_( debug, data_config, dataset, @@ -482,7 +649,23 @@ def compute_burdens( source_path.symlink_to(link_burdens) -def regress_on_gene_scoretest(gene: str, burdens: np.ndarray, model_score): +def regress_on_gene_scoretest( + gene: str, + burdens: np.ndarray, + model_score, +) -> Tuple[List[str], List[float], List[float]]: + """ + Perform regression on a gene using the score test. + + :param gene: Gene name. + :type gene: str + :param burdens: Burden scores associated with the gene. + :type burdens: np.ndarray + :param model_score: Model for score test. + :type model_score: Any + :return: Tuple containing gene name, beta, and p-value. + :rtype: Tuple[List[str], List[float], List[float]] + """ burdens = burdens.reshape(burdens.shape[0], -1) logger.info(f"Burdens shape: {burdens.shape}") @@ -499,8 +682,11 @@ def regress_on_gene_scoretest(gene: str, burdens: np.ndarray, model_score): f"gene {gene}, p-value: {pv}, using saddle instead." 
) pv = model_score.pv_alt_model(burdens, method="saddle") - - beta = model_score.coef(burdens)["beta"][0, 0] + # beta only for linear models + try: + beta = model_score.coef(burdens)["beta"][0, 0] + except: + beta = None genes_params_pvalues = ([], [], []) genes_params_pvalues[0].append(gene) @@ -517,7 +703,25 @@ def regress_on_gene( x_pheno: np.ndarray, use_bias: bool, use_x_pheno: bool, -) -> Optional[Tuple[List[str], List[float], List[float]]]: +) -> Tuple[List[str], List[float], List[float]]: + """ + Perform regression on a gene using Ordinary Least Squares (OLS). + + :param gene: Gene name. + :type gene: str + :param X: Burden score data. + :type X: np.ndarray + :param y: Y phenotype data. + :type y: np.ndarray + :param x_pheno: X phenotype data. + :type x_pheno: np.ndarray + :param use_bias: Flag to include bias term. + :type use_bias: bool + :param use_x_pheno: Flag to include x phenotype data in regression. + :type use_x_pheno: bool + :return: Tuple containing gene name, beta, and p-value. + :rtype: Tuple[List[str], List[float], List[float]] + """ X = X.reshape(X.shape[0], -1) if np.all(np.abs(X) < 1e-6): logger.warning(f"Burden for gene {gene} is 0 for all samples; skipping") @@ -561,6 +765,30 @@ def regress_( use_x_pheno: bool = True, do_scoretest: bool = True, ) -> pd.DataFrame: + """ + Perform regression on multiple genes. + + :param config: Configuration dictionary. + :type config: Dict + :param use_bias: Flag to include bias term when performing OLS regression. + :type use_bias: bool + :param burdens: Burden score data. + :type burdens: np.ndarray + :param y: Y phenotype data. + :type y: np.ndarray + :param gene_indices: Indices of genes. + :type gene_indices: np.ndarray + :param genes: Gene names. + :type genes: pd.Series + :param x_pheno: X phenotype data. + :type x_pheno: np.ndarray + :param use_x_pheno: Flag to include x phenotype data when performing OLS regression, defaults to True. + :type use_x_pheno: bool + :param do_scoretest: Flag to use the scoretest from SEAK, defaults to True. + :type do_scoretest: bool + :return: DataFrame containing regression results on all genes. + :rtype: pd.DataFrame + """ assert len(gene_indices) == len(genes) logger.info(f"Computing associations") @@ -579,7 +807,12 @@ def regress_( logger.info(f"X shape: {X.shape}, Y shape: {y.shape}") # compute null_model for score test - model_score = scoretest.ScoretestNoK(y, X) + if len(np.unique(y)) == 2: + logger.info("Fitting binary model since only found two distinct y values") + model_score = scoretest.ScoretestLogit(y, X) + else: + logger.info("Fitting linear model") + model_score = scoretest.ScoretestNoK(y, X) genes_betas_pvals = [ regress_on_gene_scoretest(gene, burdens[mask, i], model_score) for i, gene in tqdm( @@ -635,6 +868,33 @@ def regress( do_scoretest: bool, sample_file: Optional[str], ): + """ + Perform regression analysis. + + :param debug: Flag for debugging. + :type debug: bool + :param chunk: Index of the chunk of data, defaults to 0. + :type chunk: int + :param n_chunks: Number of chunks to split data for processing, defaults to 1. + :type n_chunks: int + :param use_bias: Flag to include bias term when performing OLS regression. + :type use_bias: bool + :param gene_file: Path to the gene file. + :type gene_file: str + :param repeat: Index of the repeat, defaults to 0. + :type repeat: int + :param config_file: Path to the configuration file. + :type config_file: str + :param burden_dir: Path to the directory containing burdens.zarr file. 
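# --- Editor's illustrative sketch (not part of this diff) ---
# The dispatch rule introduced above for the null model of the score test:
# exactly two distinct phenotype values -> logistic null model (ScoretestLogit),
# otherwise a linear null model (ScoretestNoK). The helper below is hypothetical
# and only mirrors the check; the real objects come from SEAK's scoretest module.
import numpy as np

def null_model_kind(y: np.ndarray) -> str:
    return "logistic (ScoretestLogit)" if len(np.unique(y)) == 2 else "linear (ScoretestNoK)"

print(null_model_kind(np.array([0.0, 1.0, 1.0, 0.0])))  # logistic (ScoretestLogit)
print(null_model_kind(np.array([1.2, 0.7, 3.4, 2.2])))  # linear (ScoretestNoK)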
+ :type burden_dir: str + :param out_dir: Path to the output directory. + :type out_dir: str + :param do_scoretest: Flag to use the scoretest from SEAK. + :type do_scoretest: bool + :param sample_file: Path to the sample file. + :type sample_file: Optional[str] + :return: Regression results saved to out_dir as "burden_associations_{chunk}.parquet" + """ logger.info("Loading saved burdens") y = zarr.open(Path(burden_dir) / "y.zarr")[:] burdens = zarr.open(Path(burden_dir) / "burdens.zarr")[:, :, repeat] @@ -703,6 +963,17 @@ def regress( def combine_regression_results( result_files: Tuple[str], out_file: str, model_name: Optional[str] ): + """ + Combine multiple regression result files. + + :param result_files: List of paths to regression result files. + :type result_files: Tuple[str] + :param out_file: Path to the output file. + :type out_file: str + :param model_name: Name of the regression model. + :type model_name: Optional[str] + :return: Concatenated regression results saved to a parquet file. + """ logger.info(f"Concatenating results") results = pd.concat([pd.read_parquet(f, engine="pyarrow") for f in result_files]) diff --git a/deeprvat/deeprvat/config.py b/deeprvat/deeprvat/config.py index 410f4141..1d4de29d 100644 --- a/deeprvat/deeprvat/config.py +++ b/deeprvat/deeprvat/config.py @@ -41,6 +41,28 @@ def update_config( seed_genes_out: Optional[str], new_config_file: str, ): + """ + Select seed genes based on baseline results and update the configuration file. + + :param old_config_file: Path to the old configuration file. + :type old_config_file: str + :param phenotype: Phenotype to update in the configuration. + :type phenotype: Optional[str] + :param seed_gene_dir: Directory containing seed genes. + :type seed_gene_dir: Optional[str] + :param baseline_results: Paths to baseline result files. + :type baseline_results: Tuple[str] + :param baseline_results_out: Path to save the updated baseline results. + :type baseline_results_out: Optional[str] + :param seed_genes_out: Path to save the seed genes. + :type seed_genes_out: Optional[str] + :param new_config_file: Path to the new configuration file. + :type new_config_file: str + :raises ValueError: If neither --seed-gene-dir nor --baseline-results is specified. + :return: Updated configuration file saved to new_config.yaml. + Selected seed genes saved to seed_genes_out.parquet. + Optionally, save baseline results to a parquet file if baseline_results_out is specified. + """ if seed_gene_dir is None and len(baseline_results) == 0: raise ValueError( "One of --seed-gene-dir and --baseline-results " "must be specified" diff --git a/deeprvat/deeprvat/evaluate.py b/deeprvat/deeprvat/evaluate.py index 4f14a188..f3b3e4ff 100644 --- a/deeprvat/deeprvat/evaluate.py +++ b/deeprvat/deeprvat/evaluate.py @@ -77,7 +77,7 @@ def get_baseline_results( ( r["type"].split("/")[0], r["type"].split("/")[1], - ): f"{r['base']}/{pheno}/{r['type']}/eval/burden_associations_testing.parquet" + ): f"{r['base']}/{pheno}/{r['type']}/eval/burden_associations.parquet" for r in config["baseline_results"] } diff --git a/deeprvat/deeprvat/models.py b/deeprvat/deeprvat/models.py index 633bd63c..7f1739bb 100644 --- a/deeprvat/deeprvat/models.py +++ b/deeprvat/deeprvat/models.py @@ -43,6 +43,11 @@ def get_hparam(module: pl.LightningModule, param: str, default: Any): class BaseModel(pl.LightningModule): + """ + Base class containing functions that will be called by PyTorch Lightning in the + background by default. 
+ """ + def __init__( self, config: dict, @@ -53,6 +58,23 @@ def __init__( stage: str = "train", **kwargs, ): + """ + Initializes BaseModel. + + :param config: Represents the content of config.yaml. + :type config: dict + :param n_annotations: Contains the number of annotations used for each phenotype. + :type n_annotations: Dict[str, int] + :param n_covariates: Contains the number of covariates used for each phenotype. + :type n_covariates: Dict[str, int] + :param n_genes: Contains the number of genes used for each phenotype. + :type n_genes: Dict[str, int] + :param phenotypes: Contains the phenotypes used during training. + :type phenotypes: List[str] + :param stage: Contains a prefix indicating the dataset the model is operating on. Defaults to "train". (optional) + :type stage: str + :param kwargs: Additional keyword arguments. + """ super().__init__() self.save_hyperparameters(config) self.save_hyperparameters(kwargs) @@ -75,6 +97,10 @@ def __init__( raise ValueError("Unknown objective_mode configuration parameter") def configure_optimizers(self) -> torch.optim.Optimizer: + """ + Function used to setup an optimizer and scheduler by their + parameters which are specified in config + """ optimizer_config = self.hparams["optimizer"] optimizer_class = getattr(torch.optim, optimizer_config["type"]) optimizer = optimizer_class( @@ -100,9 +126,25 @@ def configure_optimizers(self) -> torch.optim.Optimizer: return optimizer def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: + """ + Function called by trainer during training and returns the loss used + to update weights and biases. + + :param batch: A dictionary containing the batch data. + :type batch: dict + :param batch_idx: The index of the current batch. + :type batch_idx: int + + :returns: torch.Tensor: The loss value computed to update weights and biases + based on the predictions. + :raises RuntimeError: If NaNs are found in the training loss. + """ + # calls DeepSet.forward() y_pred_by_pheno = self(batch) results = dict() + # for all metrics we want to evaluate (specified in config) for name, fn in self.metric_fns.items(): + # compute mean distance in between ground truth and predicted score. results[name] = torch.mean( torch.stack( [ @@ -112,22 +154,49 @@ def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: ) ) self.log(f"{self.hparams.stage}_{name}", results[name]) - + # set loss from which we compute backward passes loss = results[self.hparams.metrics["loss"]] if torch.any(torch.isnan(loss)): raise RuntimeError("NaNs found in training loss") return loss def validation_step(self, batch: dict, batch_idx: int): + """ + During validation, we do not compute backward passes, such that we can accumulate + phenotype predictions and evaluate them afterward as a whole. + + :param batch: A dictionary containing the validation batch data. + :type batch: dict + :param batch_idx: The index of the current validation batch. + :type batch_idx: int + + :returns: dict: A dictionary containing phenotype predictions ("y_pred_by_pheno") + and corresponding ground truth values ("y_by_pheno"). + """ y_by_pheno = {pheno: pheno_batch["y"] for pheno, pheno_batch in batch.items()} return {"y_pred_by_pheno": self(batch), "y_by_pheno": y_by_pheno} def validation_epoch_end( self, prediction_y: List[Dict[str, Dict[str, torch.Tensor]]] ): + """ + Evaluate accumulated phenotype predictions at the end of the validation epoch. 
+ + This function takes a list of dictionaries containing accumulated phenotype predictions + and corresponding ground truth values obtained during the validation process. It computes + various metrics based on these predictions and logs the results. + + :param prediction_y: A list of dictionaries containing accumulated phenotype predictions + and corresponding ground truth values obtained during the validation process. + :type prediction_y: List[Dict[str, Dict[str, torch.Tensor]]] + + :return: None + :rtype: None + """ y_pred_by_pheno = dict() y_by_pheno = dict() for result in prediction_y: + # create a dict for each phenotype that includes all respective predictions pred = result["y_pred_by_pheno"] for pheno, ys in pred.items(): y_pred_by_pheno[pheno] = torch.cat( @@ -138,14 +207,16 @@ def validation_epoch_end( ys, ] ) - + # create a dict for each phenotype that includes the respective ground truth target = result["y_by_pheno"] for pheno, ys in target.items(): y_by_pheno[pheno] = torch.cat( [y_by_pheno.get(pheno, torch.tensor([], device=self.device)), ys] ) + # create a dict for each phenotype that stores the respective loss results = dict() + # for all metrics we want to evaluate (specified in config) for name, fn in self.metric_fns.items(): results[name] = torch.mean( torch.stack( @@ -156,15 +227,36 @@ def validation_epoch_end( ) ) self.log(f"val_{name}", results[name]) - + # consider all metrics only store the most min/max in self.best_objective + # to determine if progress was made in the last training epoch. self.best_objective = self.objective_operation( self.best_objective, results[self.hparams.metrics["objective"]].item() ) def test_step(self, batch: dict, batch_idx: int): + """ + During testing, we do not compute backward passes, such that we can accumulate + phenotype predictions and evaluate them afterward as a whole. + + :param batch: A dictionary containing the testing batch data. + :type batch: dict + :param batch_idx: The index of the current testing batch. + :type batch_idx: int + + :returns: dict: A dictionary containing phenotype predictions ("y_pred") + and corresponding ground truth values ("y"). + :rtype: dict + """ return {"y_pred": self(batch), "y": batch["y"]} def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]): + """ + Evaluate accumulated phenotype predictions at the end of the testing epoch. + + :param prediction_y: A list of dictionaries containing accumulated phenotype predictions + and corresponding ground truth values obtained during the testing process. + :type prediction_y: List[Dict[str, Dict[str, torch.Tensor]]] + """ y_pred = torch.cat([p["y_pred"] for p in prediction_y]) y = torch.cat([p["y"] for p in prediction_y]) @@ -182,6 +274,15 @@ def configure_callbacks(self): class DeepSetAgg(pl.LightningModule): + """ + class contains the gene impairment module used for burden computation. + + Variants are fed through an embedding network Phi to compute a variant embedding. + The variant embedding is processed by a permutation-invariant aggregation to yield a gene embedding. + Afterward, the second network Rho estimates the final gene impairment score. + All parameters of the gene impairment module are shared across genes and traits. + """ + def __init__( self, n_annotations: int, @@ -196,6 +297,32 @@ def __init__( use_sigmoid: bool = False, reverse: bool = False, ): + """ + Initializes the DeepSetAgg module. + + :param n_annotations: Number of annotations. + :type n_annotations: int + :param phi_layers: Number of layers in Phi. 
+ :type phi_layers: int + :param phi_hidden_dim: Internal dimensionality of linear layers in Phi. + :type phi_hidden_dim: int + :param rho_layers: Number of layers in Rho. + :type rho_layers: int + :param rho_hidden_dim: Internal dimensionality of linear layers in Rho. + :type rho_hidden_dim: int + :param activation: Activation function used; should match its name in torch.nn. + :type activation: str + :param pool: Invariant aggregation function used to aggregate gene variants. Possible values: 'max', 'sum'. + :type pool: str + :param output_dim: Number of burden scores. Defaults to 1. (optional) + :type output_dim: int + :param dropout: Probability by which some parameters are set to 0. (optional) + :type dropout: Optional[float] + :param use_sigmoid: Whether to project burden scores to [0, 1]. Also used as a linear activation function during training. Defaults to False. (optional) + :type use_sigmoid: bool + :param reverse: Whether to reverse the burden score (used during association testing). Defaults to False. (optional) + :type reverse: bool + """ super().__init__() self.output_dim = output_dim @@ -205,6 +332,7 @@ def __init__( self.use_sigmoid = use_sigmoid self.reverse = reverse + # setup of Phi input_dim = n_annotations phi = [] for l in range(phi_layers): @@ -216,10 +344,12 @@ def __init__( input_dim = output_dim self.phi = nn.Sequential(OrderedDict(phi)) + # setup permutation-invariant aggregation function if pool not in ("sum", "max"): raise ValueError(f"Unknown pooling operation {pool}") self.pool = pool + # setup of Rho rho = [] for l in range(rho_layers - 1): output_dim = rho_hidden_dim @@ -231,12 +361,33 @@ def __init__( rho.append( (f"rho_linear_{rho_layers - 1}", nn.Linear(input_dim, self.output_dim)) ) + # No final non-linear activation function to keep the relationship between + # gene impairment scores and phenotypes linear self.rho = nn.Sequential(OrderedDict(rho)) def set_reverse(self, reverse: bool = True): + """ + Reverse burden score during association testing if the model predicts in negative space. + + :param reverse: Indicates whether the 'reverse' attribute should be set to True or False. + Defaults to True. + :type reverse: bool + + Note: + Compare associate.py, reverse_models() for further detail + """ self.reverse = reverse def forward(self, x): + """ + Perform a forward pass through the model. + + :param x: Batched input data + :type x: tensor + + :returns: Burden scores + :rtype: tensor + """ x = self.phi(x.permute((0, 1, 3, 2))) # x.shape = samples x genes x variants x phi_latent if self.pool == "sum": @@ -245,7 +396,7 @@ def forward(self, x): x = torch.max(x, dim=2).values # Now x.shape = samples x genes x phi_latent x = self.rho(x) - # x.shape = samples x genes x rho_latent + # x.shape = samples x genes x 1 if self.reverse: x = -x if self.use_sigmoid: @@ -254,6 +405,13 @@ def forward(self, x): class DeepSet(BaseModel): + """ + Wrapper class for burden computation, that also does phenotype prediction. + It inherits parameters from BaseModel, which is where Pytorch Lightning specific functions + like "training_step" or "validation_epoch_end" can be found. + Those functions are called in background by default. + """ + def __init__( self, config: dict, @@ -266,6 +424,28 @@ def __init__( reverse: bool = False, **kwargs, ): + """ + Initialize the DeepSet model. + + :param config: Containing the content of config.yaml. + :type config: dict + :param n_annotations: Contains the number of annotations used for each phenotype. 
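# --- Editor's illustrative sketch (not part of this diff) ---
# Shape flow through the gene impairment module described above:
# Phi embeds each variant, a permutation-invariant pool collapses the variant
# dimension, and Rho maps the gene embedding to one impairment score.
# Layer widths and batch dimensions are arbitrary toy values.
import torch
import torch.nn as nn

samples, genes, annotations, variants = 2, 3, 5, 7
phi = nn.Sequential(nn.Linear(annotations, 16), nn.ReLU())
rho = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 1))

x = torch.randn(samples, genes, annotations, variants)
h = phi(x.permute(0, 1, 3, 2))     # samples x genes x variants x 16
h = torch.max(h, dim=2).values     # pool over variants -> samples x genes x 16
score = rho(h)                     # samples x genes x 1
print(score.shape)                 # torch.Size([2, 3, 1])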
+ :type n_annotations: Dict[str, int] + :param n_covariates: Contains the number of covariates used for each phenotype. + :type n_covariates: Dict[str, int] + :param n_genes: Contains the number of genes used for each phenotype. + :type n_genes: Dict[str, int] + :param phenotypes: Contains the phenotypes used during training. + :type phenotypes: List[str] + :param agg_model: Model used for burden computation. If not provided, it will be initialized. (optional) + :type agg_model: Optional[pl.LightningModule / nn.Module] + :param use_sigmoid: Determines if burden scores should be projected to [0, 1]. Acts as a linear activation + function to mimic association testing during training. + :type use_sigmoid: bool + :param reverse: Determines if the burden score should be reversed (used during association testing). + :type reverse: bool + :param kwargs: Additional keyword arguments. + """ super().__init__( config, n_annotations, n_covariates, n_genes, phenotypes, **kwargs ) @@ -277,6 +457,9 @@ def __init__( pool = get_hparam(self, "pool", "sum") dropout = get_hparam(self, "dropout", None) + # self.agg_model compresses a batch + # from: samples x genes x annotations x variants + # to: samples x genes if agg_model is not None: self.agg_model = agg_model else: @@ -293,7 +476,11 @@ def __init__( reverse=reverse, ) self.agg_model.train(False if self.hparams.stage == "val" else True) + # afterwards genes are concatenated with covariates + # to: samples x (genes + covariates) + # dict of various linear layers used for phenotype prediction. + # Returns can be tested against ground truth data. self.gene_pheno = nn.ModuleDict( { pheno: nn.Linear( @@ -304,6 +491,19 @@ def __init__( ) def forward(self, batch): + """ + Forward pass through the model. + + :param batch: Dictionary of phenotypes, each containing the following keys: + - indices (tensor): Indices for the underlying dataframe. + - covariates (tensor): Covariates of samples, e.g., age. Content: samples x covariates. + - rare_variant_annotations (tensor): Annotated genomic variants. Content: samples x genes x annotations x variants. + - y (tensor): Actual phenotypes (ground truth data). + :type batch: dict + + :returns: Dictionary containing predicted phenotypes + :rtype: dict + """ result = dict() for pheno, this_batch in batch.items(): x = this_batch["rare_variant_annotations"] @@ -318,7 +518,23 @@ def forward(self, batch): class LinearAgg(pl.LightningModule): + """ + To capture only linear effect, this model can be used as it only uses a single + linear layer without a non-linear activation function. + It still contains the gene impairment module used for burden computation. + """ + def __init__(self, n_annotations: int, pool: str, output_dim: int = 1): + """ + Initialize the LinearAgg model. + + :param n_annotations: Number of annotations. + :type n_annotations: int + :param pool: Pooling method ("sum" or "max") to be used. + :type pool: str + :param output_dim: Dimensionality of the output. Defaults to 1. (optional) + :type output_dim: int + """ super().__init__() self.output_dim = output_dim @@ -328,6 +544,15 @@ def __init__(self, n_annotations: int, pool: str, output_dim: int = 1): self.linear = nn.Linear(n_annotations, self.output_dim) def forward(self, x): + """ + Perform a forward pass through the model. 
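# --- Editor's illustrative sketch (not part of this diff) ---
# The per-phenotype prediction heads described above: gene impairment scores are
# concatenated with the covariates and passed through one linear layer per
# phenotype. Phenotype names and sizes are toy values.
import torch
import torch.nn as nn

n_covariates, n_genes, n_samples = 3, 10, 4
gene_pheno = nn.ModuleDict({
    "Triglycerides": nn.Linear(n_covariates + n_genes, 1),
    "LDL_direct": nn.Linear(n_covariates + n_genes, 1),
})
covariates = torch.randn(n_samples, n_covariates)
burdens = torch.randn(n_samples, n_genes)   # output of the aggregation model
features = torch.cat((covariates, burdens), dim=1)
result = {pheno: head(features) for pheno, head in gene_pheno.items()}
print({pheno: y_hat.shape for pheno, y_hat in result.items()})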
+ + :param x: Batched input data + :type x: tensor + + :returns: Burden scores + :rtype: tensor + """ x = self.linear( x.permute((0, 1, 3, 2)) ) # x.shape = samples x genes x variants x output_dim @@ -340,6 +565,12 @@ def forward(self, x): class TwoLayer(BaseModel): + """ + Wrapper class to capture linear effects. Inherits parameters from BaseModel, + which is where Pytorch Lightning specific functions like "training_step" or + "validation_epoch_end" can be found. Those functions are called in background by default. + """ + def __init__( self, config: dict, @@ -349,6 +580,21 @@ def __init__( agg_model: Optional[nn.Module] = None, **kwargs, ): + """ + Initializes the TwoLayer model. + + :param config: Represents the content of config.yaml. + :type config: dict + :param n_annotations: Number of annotations. + :type n_annotations: int + :param n_covariates: Number of covariates. + :type n_covariates: int + :param n_genes: Number of genes. + :type n_genes: int + :param agg_model: Model used for burden computation. If not provided, it will be initialized. (optional) + :type agg_model: Optional[nn.Module] + :param kwargs: Additional keyword arguments. + """ super().__init__(config, n_annotations, n_covariates, n_genes, **kwargs) logger.info("Initializing TwoLayer model with parameters:") @@ -374,6 +620,19 @@ def __init__( self.gene_pheno = nn.Linear(self.hparams.n_covariates + self.hparams.n_genes, 1) def forward(self, batch): + """ + Forward pass through the model. + + :param batch: Dictionary of phenotypes, each containing the following keys: + - indices (tensor): Indices for the underlying dataframe. + - covariates (tensor): Covariates of samples, e.g., age. Content: samples x covariates. + - rare_variant_annotations (tensor): Annotated genomic variants. Content: samples x genes x annotations x variants. + - y (tensor): Actual phenotypes (ground truth data). 
+ :type batch: dict + + :returns: Dictionary containing predicted phenotypes + :rtype: dict + """ # samples x genes x annotations x variants x = batch["rare_variant_annotations"] x = self.agg_model(x).squeeze(dim=2) # samples x genes diff --git a/deeprvat/deeprvat/train.py b/deeprvat/deeprvat/train.py index 8889dfae..ecbdef65 100644 --- a/deeprvat/deeprvat/train.py +++ b/deeprvat/deeprvat/train.py @@ -2,23 +2,34 @@ import gc import itertools import logging -import sys import pickle +import random import shutil +import sys from pathlib import Path from pprint import pformat, pprint -from typing import Dict, Optional, Tuple +from tempfile import TemporaryDirectory +from typing import Dict, Optional, Tuple, Union -import torch.nn.functional as F -import numpy as np import click import math +import numpy as np import optuna import pandas as pd import pytorch_lightning as pl +import torch.nn.functional as F +import deeprvat.deeprvat.models as deeprvat_models import torch import yaml import zarr +from deeprvat.data import DenseGTDataset +from deeprvat.metrics import ( + AveragePrecisionWithLogits, + PearsonCorr, + PearsonCorrTorch, + RSquared, +) +from deeprvat.utils import resolve_path_with_env, suggest_hparams from numcodecs import Blosc from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks.early_stopping import EarlyStopping @@ -27,15 +38,6 @@ from torch.utils.data import DataLoader, Dataset, Subset from tqdm import tqdm -import deeprvat.deeprvat.models as deeprvat_models -from deeprvat.data import DenseGTDataset -from deeprvat.metrics import ( - PearsonCorr, - PearsonCorrTorch, - RSquared, - AveragePrecisionWithLogits, -) -from deeprvat.utils import suggest_hparams logging.basicConfig( format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s", @@ -76,12 +78,69 @@ def cli(): pass +def subset_samples( + input_tensor: torch.Tensor, + covariates: torch.Tensor, + y: torch.Tensor, + min_variant_count: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # First sum over annotations (dim 2) for each variant in each gene. + # Then get the number of non-zero values across all variants in all + # genes for each sample. 
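# --- Editor's illustrative sketch (not part of this diff) ---
# Toy illustration of the variant-count mask computed just below: a sample is
# kept when at least `min_variant_count` variant slots are non-zero in any gene
# or annotation. Array sizes here are made up.
import numpy as np

toy = np.zeros((3, 2, 4, 5))          # samples x genes x annotations x variants
toy[0, 0, 1, :3] = 1.0                # sample 0 has 3 non-zero variant slots
toy[1, 1, 0, 0] = 1.0                 # sample 1 has 1 non-zero variant slot
n_variants_per_sample = np.sum(np.any(toy, axis=(1, 2)), axis=1)
print(n_variants_per_sample)          # [3 1 0]
print(n_variants_per_sample >= 2)     # [ True False False]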
+ n_samples_orig = input_tensor.shape[0] + + # n_variants_per_sample = np.sum( + # np.sum(input_tensor.numpy(), axis=2) != 0, axis=(1, 2) + # ) + # n_variant_mask = n_variants_per_sample >= min_variant_count + n_variant_mask = ( + np.sum(np.any(input_tensor.numpy(), axis=(1, 2)), axis=1) >= min_variant_count + ) + + # Also make sure we don't have NaN values for y + nan_mask = ~y.squeeze().isnan() + mask = n_variant_mask & nan_mask.numpy() + + # Subset all the tensors + input_tensor = input_tensor[mask] + covariates = covariates[mask] + y = y[mask] + + logger.info(f"{input_tensor.shape[0]} / {n_samples_orig} samples kept") + + return input_tensor, covariates, y + + def make_dataset_( - config: Dict, - debug: bool = False, - training_dataset_file: str = None, - pickle_only: bool = False, + debug: bool, + pickle_only: bool, + compression_level: int, + training_dataset_file: Optional[str], + config_file: Union[str, Path], + input_tensor_out_file: str, + covariates_out_file: str, + y_out_file: str, ): + """ + Subfunction of make_dataset() + Convert a dataset file to the sparse format used for training and testing associations + + :param config: Dictionary containing configuration parameters, build from YAML file + :type config: Dict + :param debug: Use a strongly reduced dataframe (optional) + :type debug: bool + :param training_dataset_file: Path to the file in which training data is stored. (optional) + :type training_dataset_file: str + :param pickle_only: If True, only store dataset as pickle file and return None. (optional) + :type pickle_only: bool + + :returns: Tuple containing input_tensor, covariates, and target values. + :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + """ + + with open(config_file) as f: + config = yaml.safe_load(f) + n_phenotypes = config.get("n_phenotypes", None) if n_phenotypes is not None: if "seed_genes" in config: @@ -113,13 +172,10 @@ def make_dataset_( or training_dataset_file is None or not Path(training_dataset_file).is_file() ): - variant_file = config["training_data"].get( - "variant_file", - f'{config["training_data"]["gt_file"][:-3]}_variants.parquet', - ) + # load data into sparse data format ds = DenseGTDataset( gt_file=config["training_data"]["gt_file"], - variant_file=variant_file, + variant_file=config["training_data"]["variant_file"], split="", skip_y_na=True, **config["training_data"]["dataset_config"], @@ -166,13 +222,31 @@ def make_dataset_( input_tensor = torch.cat( [ F.pad(r, (0, max_n_variants - r.shape[-1]), value=pad_value) - for r in tqdm(rare_batches, file=sys.stdout) + for r in rare_batches ] ) covariates = torch.cat([b["x_phenotypes"] for b in batches]) y = torch.cat([b["y"] for b in batches]) - return input_tensor, covariates, y + logger.info("Subsetting samples by min_variant_count and missing y values") + input_tensor, covariates, y = subset_samples( + input_tensor, covariates, y, config["training"]["min_variant_count"] + ) + + if not pickle_only: + logger.info("Saving tensors") + zarr.save_array( + input_tensor_out_file, + input_tensor.numpy(), + chunks=(1000, None, None, None), + compressor=Blosc(clevel=compression_level), + ) + del input_tensor + zarr.save_array(covariates_out_file, covariates.numpy()) + zarr.save_array(y_out_file, y.numpy()) + + # DEBUG + return ds.dataset @cli.command() @@ -194,54 +268,101 @@ def make_dataset( covariates_out_file: str, y_out_file: str, ): - with open(config_file) as f: - config = yaml.safe_load(f) - - input_tensor, covariates, y = make_dataset_( - config, - debug=debug, - 
training_dataset_file=training_dataset_file, - pickle_only=pickle_only, + """ + Uses function make_dataset_() to convert dataset to sparse format and stores the respective data + + :param debug: Use a strongly reduced dataframe + :type debug: bool + :param pickle_only: Flag to indicate whether only to save data using pickle + :type pickle_only: bool + :param compression_level: Level of compression in ZARR to be applied to training data. + :type compression_level: int + :param training_dataset_file: Path to the file in which training data is stored. (optional) + :type training_dataset_file: Optional[str] + :param config_file: Path to a YAML file, which serves for configuration. + :type config_file: str + :param input_tensor_out_file: Path to save the training data to. + :type input_tensor_out_file: str + :param covariates_out_file: Path to save the covariates to. + :type covariates_out_file: str + :param y_out_file: Path to save the ground truth data to. + :type y_out_file: str + + :returns: None + """ + + make_dataset_( + debug, + pickle_only, + compression_level, + training_dataset_file, + config_file, + input_tensor_out_file, + covariates_out_file, + y_out_file, ) - if not pickle_only: - logger.info("Saving tensors") - zarr.save_array( - input_tensor_out_file, - input_tensor.numpy(), - chunks=(1000, None, None, None), - compressor=Blosc(clevel=compression_level), - ) - del input_tensor - zarr.save_array(covariates_out_file, covariates.numpy()) - zarr.save_array(y_out_file, y.numpy()) class MultiphenoDataset(Dataset): + """ + class used to structure the data and present a __getitem__ function to + the dataloader, that will be used to load batches into the model + """ + def __init__( self, # input_tensor: zarr.core.Array, # covariates: zarr.core.Array, # y: zarr.core.Array, data: Dict[str, Dict], - min_variant_count: int, + # min_variant_count: int, batch_size: int, split: str = "train", cache_tensors: bool = False, + temp_dir: Optional[str] = None, + chunksize: int = 1000, # samples: Optional[Union[slice, np.ndarray]] = None, # genes: Optional[Union[slice, np.ndarray]] = None ): - "Initialization" + """ + Initialize the MultiphenoDataset. + + :param data: Underlying dataframe from which data is structured into batches. + :type data: Dict[str, Dict] + :param min_variant_count: Minimum number of variants available for each gene. + :type min_variant_count: int + :param batch_size: Number of samples/individuals available in one batch. + :type batch_size: int + :param split: Contains a prefix indicating the dataset the model operates on. Defaults to "train". (optional) + :type split: str + :param cache_tensors: Indicates if samples have been pre-loaded or need to be extracted from zarr. 
(optional) + :type cache_tensors: bool + """ + super().__init__() - self.data = data + self.data = copy.deepcopy(data) self.phenotypes = self.data.keys() logger.info( f"Initializing MultiphenoDataset with phenotypes:\n{pformat(list(self.phenotypes))}" ) self.cache_tensors = cache_tensors + if self.cache_tensors: + self.zarr_root = zarr.group() + elif temp_dir is not None: + temp_path = Path(resolve_path_with_env(temp_dir)) / "deeprvat_training" + temp_path.mkdir(parents=True, exist_ok=True) + self.input_tensor_dir = TemporaryDirectory( + prefix="training_data", dir=str(temp_path) + ) + # Create root group here - for _, pheno_data in self.data.items(): + self.chunksize = chunksize + if self.cache_tensors: + logger.info("Keeping all input tensors in main memory") + + for pheno, pheno_data in self.data.items(): if pheno_data["y"].shape == (pheno_data["input_tensor_zarr"].shape[0], 1): pheno_data["y"] = pheno_data["y"].squeeze() elif pheno_data["y"].shape != (pheno_data["input_tensor_zarr"].shape[0],): @@ -250,18 +371,40 @@ def __init__( ) if self.cache_tensors: - pheno_data["input_tensor"] = pheno_data["input_tensor_zarr"][:] + zarr.copy( + pheno_data["input_tensor_zarr"], + self.zarr_root, + name=pheno, + chunks=(self.chunksize, None, None, None), + compressor=Blosc(clevel=1), + ) + pheno_data["input_tensor_zarr"] = self.zarr_root[pheno] + # pheno_data["input_tensor"] = pheno_data["input_tensor_zarr"][:] + elif temp_dir is not None: + tensor_path = ( + Path(self.input_tensor_dir.name) / pheno / "input_tensor.zarr" + ) + zarr.copy( + pheno_data["input_tensor_zarr"], + zarr.DirectoryStore(tensor_path), + chunks=(self.chunksize, None, None, None), + compressor=Blosc(clevel=1), + ) + pheno_data["input_tensor_zarr"] = zarr.open(tensor_path) - self.min_variant_count = min_variant_count + # self.min_variant_count = min_variant_count self.samples = { pheno: pheno_data["samples"][split] for pheno, pheno_data in self.data.items() } - self.subset_samples() + + # self.subset_samples() self.total_samples = sum([s.shape[0] for s in self.samples.values()]) self.batch_size = batch_size + # index all samples and categorize them by phenotype, such that we + # get a dataframe repreenting a chain of phenotypes self.sample_order = pd.DataFrame( { "phenotype": itertools.chain( @@ -273,6 +416,7 @@ def __init__( {"phenotype": pd.api.types.CategoricalDtype()} ) self.sample_order = self.sample_order.sample(n=self.total_samples) # shuffle + # phenotype specific index; e.g. 7. element total, 2. 
element for phenotype "Urate" self.sample_order["index"] = self.sample_order.groupby("phenotype").cumcount() def __len__(self): @@ -289,62 +433,141 @@ def __getitem__(self, index): start_idx = index * self.batch_size end_idx = min(self.total_samples, start_idx + self.batch_size) batch_samples = self.sample_order.iloc[start_idx:end_idx] - samples_by_pheno = batch_samples.groupby("phenotype") + samples_by_pheno = batch_samples.groupby("phenotype", observed=True) result = dict() for pheno, df in samples_by_pheno: + # get phenotype specific sub-index idx = df["index"].to_numpy() + assert np.array_equal(idx, np.arange(idx[0], idx[-1] + 1)) + slice_ = slice(idx[0], idx[-1] + 1) - annotations = ( - self.data[pheno]["input_tensor"][idx] - if self.cache_tensors - else self.data[pheno]["input_tensor_zarr"].oindex[idx, :, :, :] - ) + # annotations = ( + # self.data[pheno]["input_tensor"][slice_] + # if self.cache_tensors + # else self.data[pheno]["input_tensor_zarr"][slice_, :, :, :] + # ) + annotations = self.data[pheno]["input_tensor_zarr"][slice_, :, :, :] result[pheno] = { - "indices": self.samples[pheno][idx], - "covariates": self.data[pheno]["covariates"][idx], - "rare_variant_annotations": annotations, - "y": self.data[pheno]["y"][idx], + "indices": self.samples[pheno][slice_], + "covariates": self.data[pheno]["covariates"][slice_], + "rare_variant_annotations": torch.tensor(annotations), + "y": self.data[pheno]["y"][slice_], } return result - def subset_samples(self): - for pheno, pheno_data in self.data.items(): - # First sum over annotations (dim 2) for each variant in each gene. - # Then get the number of non-zero values across all variants in all - # genes for each sample. - n_samples_orig = self.samples[pheno].shape[0] - - input_tensor = pheno_data["input_tensor_zarr"].oindex[self.samples[pheno]] - n_variants_per_sample = np.sum( - np.sum(input_tensor, axis=2) != 0, axis=(1, 2) - ) - n_variant_mask = n_variants_per_sample >= self.min_variant_count - - nan_mask = ~pheno_data["y"][self.samples[pheno]].isnan() - mask = n_variant_mask & nan_mask.numpy() - self.samples[pheno] = self.samples[pheno][mask] - - logger.info( - f"{pheno}: {self.samples[pheno].shape[0]} / " - f"{n_samples_orig} samples kept" - ) + # # NOTE: This function is broken with current cache_tensors behavior + # def subset_samples(self): + # for pheno, pheno_data in self.data.items(): + # # First sum over annotations (dim 2) for each variant in each gene. + # # Then get the number of non-zero values across all variants in all + # # genes for each sample. 
+ # n_samples_orig = self.samples[pheno].shape[0] + + # # TODO: Compute n_variant_mask one block of 10,000 samples at a time to reduce memory usage + # input_tensor = pheno_data["input_tensor_zarr"].oindex[self.samples[pheno]] + # n_variants_per_sample = np.sum( + # np.sum(input_tensor, axis=2) != 0, axis=(1, 2) + # ) + # n_variant_mask = n_variants_per_sample >= self.min_variant_count + + # # Also make sure we don't have NaN values for y + # nan_mask = ~pheno_data["y"][self.samples[pheno]].isnan() + # mask = n_variant_mask & nan_mask.numpy() + + # # Set the tensor indices to use and subset all the tensors + # self.samples[pheno] = self.samples[pheno][mask] + # pheno_data["y"] = pheno_data["y"][self.samples[pheno]] + # pheno_data["covariates"] = pheno_data["covariates"][self.samples[pheno]] + # if self.cache_tensors: + # pheno_data["input_tensor"] = pheno_data["input_tensor"][ + # self.samples[pheno] + # ] + # else: + # # TODO: Again do this in blocks of 10,000 samples + # # Create a temporary directory to store the zarr array + # tensor_path = ( + # Path(self.input_tensor_dir.name) / pheno / "input_tensor.zarr" + # ) + # zarr.save_array( + # tensor_path, + # pheno_data["input_tensor_zarr"][:][self.samples[pheno]], + # chunks=(self.chunksize, None, None, None), + # compressor=Blosc(clevel=1), + # ) + # pheno_data["input_tensor_zarr"] = zarr.open(tensor_path) + + # logger.info( + # f"{pheno}: {self.samples[pheno].shape[0]} / " + # f"{n_samples_orig} samples kept" + # ) + + # def index_input_tensor_zarr(self, pheno: str, indices: np.ndarray): + # # IMPORTANT!!! Never call this function after self.subset_samples() + + # x = self.data[pheno]["input_tensor_zarr"] + # first_idx = indices[0] + # last_idx = indices[-1] + # slice_ = slice(first_idx, last_idx + 1) + # arange = np.arange(first_idx, last_idx + 1) + # z = x[slice_] + # slice_indices = np.nonzero(np.isin(arange, indices)) + # return z[slice_indices] + + def index_input_tensor_zarr(self, pheno: str, indices: np.ndarray): + # IMPORTANT!!! Never call this function after self.subset_samples() + + x = self.data[pheno]["input_tensor_zarr"] + first_idx = indices[0] + last_idx = indices[-1] + slice_ = slice(first_idx, last_idx + 1) + arange = np.arange(first_idx, last_idx + 1) + z = x[slice_] + slice_indices = np.nonzero(np.isin(arange, indices)) + return z[slice_indices] class MultiphenoBaggingData(pl.LightningDataModule): + """ + Preprocess the underlying dataframe, to then load it into a dataset object + """ + def __init__( self, data: Dict[str, Dict], train_proportion: float, sample_with_replacement: bool = True, - min_variant_count: int = 1, + # min_variant_count: int = 1, upsampling_factor: int = 1, batch_size: Optional[int] = None, num_workers: Optional[int] = 0, + pin_memory: bool = False, cache_tensors: bool = False, + temp_dir: Optional[str] = None, + chunksize: int = 1000, ): + """ + Initialize the MultiphenoBaggingData. + + :param data: Underlying dataframe from which data structured into batches. + :type data: Dict[str, Dict] + :param train_proportion: Percentage by which data is divided into training/validation split. + :type train_proportion: float + :param sample_with_replacement: If True, a sample can be selected multiple times in one epoch. Defaults to True. (optional) + :type sample_with_replacement: bool + :param min_variant_count: Minimum number of variants available for each gene. Defaults to 1. (optional) + :type min_variant_count: int + :param upsampling_factor: Percentual factor by which to upsample data; >= 1. 
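The `index_input_tensor_zarr` helper above avoids fancy indexing on the zarr array by reading one contiguous slice and subsetting it in memory. A rough standalone sketch of that pattern, using a toy in-memory zarr array and arbitrary sample indices:

```python
import numpy as np
import zarr

# Toy stand-in for input_tensor_zarr: 10 samples x 2 genes x 3 annotations x 4 variants.
x = zarr.array(np.arange(10 * 2 * 3 * 4, dtype=np.float32).reshape(10, 2, 3, 4))

indices = np.array([2, 3, 5, 6])  # samples we actually want

# Read a single contiguous block covering all requested indices ...
first_idx, last_idx = int(indices[0]), int(indices[-1])
block = x[first_idx : last_idx + 1]

# ... then keep only the requested rows of that block.
arange = np.arange(first_idx, last_idx + 1)
slice_indices = np.nonzero(np.isin(arange, indices))
subset = block[slice_indices]

print(subset.shape)  # (4, 2, 3, 4)
```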
Defaults to 1. (optional) + :type upsampling_factor: int + :param batch_size: Number of samples/individuals available in one batch. Defaults to None. (optional) + :type batch_size: Optional[int] + :param num_workers: Number of workers simultaneously putting data into RAM. Defaults to 0. (optional) + :type num_workers: Optional[int] + :param cache_tensors: Indicates if samples have been pre-loaded or need to be extracted from zarr. Defaults to False. (optional) + :type cache_tensors: bool + """ logger.info("Intializing datamodule") super().__init__() @@ -390,25 +613,35 @@ def __init__( else: n_train_samples = round(n_samples * train_proportion) rng = np.random.default_rng() + # select training samples from the underlying dataframe train_samples = np.sort( rng.choice( samples, size=n_train_samples, replace=sample_with_replacement ) ) + # samples which are not part of train_samples, but in samples + # are validation samples. pheno_data["samples"] = { "train": train_samples, "val": np.setdiff1d(samples, train_samples), } self.save_hyperparameters( - "min_variant_count", + # "min_variant_count", "train_proportion", "batch_size", "num_workers", + "pin_memory", "cache_tensors", + "temp_dir", + "chunksize", ) def upsample(self) -> np.ndarray: + """ + does not work at the moment for multi-phenotype training. Needs some minor changes + to make it work again + """ unique_values = self.y.unique() if unique_values.size() != torch.Size([2]): raise ValueError( @@ -432,35 +665,56 @@ def upsample(self) -> np.ndarray: self.samples = upsampled_indices def train_dataloader(self): + """ + trainning samples have been selected, but to structure them and make them load + as a batch they are packed in a dataset class, which is then wrapped by a + dataloading object. + """ logger.info( "Instantiating training dataloader " f"with batch size {self.hparams.batch_size}" ) + dataset = MultiphenoDataset( self.data, - self.hparams.min_variant_count, + # self.hparams.min_variant_count, self.hparams.batch_size, split="train", cache_tensors=self.hparams.cache_tensors, + temp_dir=self.hparams.temp_dir, + chunksize=self.hparams.chunksize, ) return DataLoader( - dataset, batch_size=None, num_workers=self.hparams.num_workers + dataset, + batch_size=None, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, ) def val_dataloader(self): + """ + validation samples have been selected, but to structure them and make them load + as a batch they are packed in a dataset class, which is then wrapped by a + dataloading object. + """ logger.info( "Instantiating validation dataloader " f"with batch size {self.hparams.batch_size}" ) dataset = MultiphenoDataset( self.data, - self.hparams.min_variant_count, + # self.hparams.min_variant_count, self.hparams.batch_size, split="val", cache_tensors=self.hparams.cache_tensors, + temp_dir=self.hparams.temp_dir, + chunksize=self.hparams.chunksize, ) return DataLoader( - dataset, batch_size=None, num_workers=self.hparams.num_workers + dataset, + batch_size=None, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, ) @@ -473,10 +727,49 @@ def run_bagging( trial_id: Optional[int] = None, debug: bool = False, ) -> Optional[float]: + """ + Main function called during training. Also used for trial pruning and sampling new parameters in optuna. + + :param config: Dictionary containing configuration parameters, build from YAML file + :type config: Dict + :param data: Dict of phenotypes, each containing a dict storing the underlying data. 
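A minimal sketch of the train/validation split performed above (toy sample indices and a hypothetical `train_proportion`; with `sample_with_replacement=True` a sample may be drawn several times, and every sample never drawn ends up in the validation set):

```python
import numpy as np

samples = np.arange(10)        # toy sample indices
train_proportion = 0.7         # hypothetical value
sample_with_replacement = True

n_train_samples = round(len(samples) * train_proportion)
rng = np.random.default_rng(seed=0)

train_samples = np.sort(
    rng.choice(samples, size=n_train_samples, replace=sample_with_replacement)
)
val_samples = np.setdiff1d(samples, train_samples)

print("train:", train_samples)
print("val:  ", val_samples)
```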
+ :type data: Dict[str, Dict] + :param log_dir: Path to where logs are written. + :type log_dir: str + :param checkpoint_file: Path to where the weights of the trained model should be saved. (optional) + :type checkpoint_file: Optional[str] + :param trial: Optuna object generated from the study. (optional) + :type trial: Optional[optuna.trial.Trial] + :param trial_id: Current trial in range n_trials. (optional) + :type trial_id: Optional[int] + :param debug: Use a strongly reduced dataframe + :type debug: bool + + :returns: Optional[float]: computes the lowest scores of all loss metrics and returns their average + :rtype: Optional[float] + """ + + # if hyperparameter optimization is performed (train(); hpopt_file != None) if trial is not None: if trial_id is not None: + # differentiate various repeats in their individual optimization trial.set_user_attr("user_id", trial_id) + # Parameters set in config can be used to indicate hyperparameter optimization. + # Such cases can be spotted by the following exemplary pattern: + # + # phi_hidden_dim: 20 + # hparam: + # type : int + # args: + # - 16 + # - 64 + # kwargs: + # step: 16 + # + # this line should be translated into: + # phi_layers = optuna.suggest_int(name="phi_hidden_dim", low=16, high=64, step=16) + # and afterward replace the respective area in config to set the suggestion. config["model"]["config"] = suggest_hparams(config["model"]["config"], trial) logger.info("Model hyperparameters this trial:") pprint(config["model"]["config"]) @@ -486,6 +779,8 @@ def run_bagging( with open(config_out, "w") as f: yaml.dump(config, f) + # in practice we only train a single bag, as there are + # theoretical reasons to omit bagging w.r.t. association testing n_bags = config["training"]["n_bags"] if not debug else 3 train_proportion = config["training"].get("train_proportion", None) logger.info(f"Training {n_bags} bagged models") @@ -509,12 +804,15 @@ def run_bagging( for k, v in config["training"].items() if k in ( - "min_variant_count", + # "min_variant_count", "upsampling_factor", "sample_with_replacement", "cache_tensors", + "temp_dir", + "chunksize", ) } + # load data into the required formate dm = MultiphenoBaggingData( this_data, train_proportion, @@ -522,6 +820,7 @@ def run_bagging( **config["training"]["dataloader_config"], ) + # setup the model architecture as specified in config model_class = getattr(deeprvat_models, config["model"]["type"]) model = model_class( config=config["model"]["config"], @@ -539,6 +838,8 @@ def run_bagging( objective = "val_" + config["model"]["config"]["metrics"]["objective"] checkpoint_callback = ModelCheckpoint(monitor=objective) callbacks = [checkpoint_callback] + + # to prune underperforming trials we enable a pruning strategy that can be set in config if "early_stopping" in config: callbacks.append( EarlyStopping(monitor=objective, **config["early_stopping"]) @@ -548,14 +849,17 @@ def run_bagging( config["pl_trainer"]["min_epochs"] = 10 config["pl_trainer"]["max_epochs"] = 20 + # initialize trainer, which will call background functionality trainer = pl.Trainer( logger=tb_logger, callbacks=callbacks, **config.get("pl_trainer", {}) ) while True: try: + # actual training of the model trainer.fit(model, dm) except RuntimeError as e: + # if batch_size is choosen to big, it will be reduced until it fits the GPU logging.error(f"Caught RuntimeError: {e}") if str(e).find("CUDA out of memory") != -1: if dm.hparams.batch_size > 4: @@ -654,6 +958,35 @@ def train( log_dir: str, hpopt_file: str, ): + """ + Main function 
called during training. Also used for trial pruning and sampling new parameters in Optuna. + + :param debug: Use a strongly reduced dataframe + :type debug: bool + :param training_gene_file: Path to a pickle file specifying on which genes training should be executed. (optional) + :type training_gene_file: Optional[str] + :param n_trials: Number of trials to be performed by the given setting. + :type n_trials: int + :param trial_id: Current trial in range n_trials. (optional) + :type trial_id: Optional[int] + :param sample_file: Path to a pickle file specifying which samples should be considered during training. (optional) + :type sample_file: Optional[str] + :param phenotype: Array of phenotypes, containing an array of paths where the underlying data is stored: + - str: Phenotype name + - str: Annotated gene variants as zarr file + - str: Covariates each sample as zarr file + - str: Ground truth phenotypes as zarr file + :type phenotype: Tuple[Tuple[str, str, str, str]] + :param config_file: Path to a YAML file, which serves for configuration. + :type config_file: str + :param log_dir: Path to where logs are stored. + :type log_dir: str + :param hpopt_file: Path to where a .db file should be created in which the results of hyperparameter optimization are stored. + :type hpopt_file: str + + :raises ValueError: If no phenotype option is specified. + """ + if len(phenotype) == 0: raise ValueError("At least one --phenotype option must be specified") @@ -681,13 +1014,20 @@ def train( samples = slice(None) data = dict() + # pack underlying data into a single dict that can be passed to downstream functions for pheno, input_tensor_file, covariates_file, y_file in phenotype: data[pheno] = dict() - data[pheno]["input_tensor_zarr"] = zarr.open(input_tensor_file, mode="r") + data[pheno]["input_tensor_zarr"] = zarr.open( + input_tensor_file, mode="r" + ) # TODO: subset here? data[pheno]["covariates"] = torch.tensor( zarr.open(covariates_file, mode="r")[:] - )[samples] - data[pheno]["y"] = torch.tensor(zarr.open(y_file, mode="r")[:])[samples] + )[ + samples + ] # TODO: or maybe shouldn't subset here? + data[pheno]["y"] = torch.tensor(zarr.open(y_file, mode="r")[:])[ + samples + ] # TODO: or maybe shouldn't subset here? if training_gene_file is not None: with open(training_gene_file, "rb") as f: @@ -773,6 +1113,22 @@ def train( def best_training_run( debug: bool, log_dir: str, checkpoint_dir: str, hpopt_db: str, config_file_out: str ): + """ + Function to extract the best trial from an Optuna study and handle associated model checkpoints and configurations. + + :param debug: Use a strongly reduced dataframe + :type debug: bool + :param log_dir: Path to where logs are stored. + :type log_dir: str + :param checkpoint_dir: Directory where checkpoints have been stored. + :type checkpoint_dir: str + :param hpopt_db: Path to the database file containing the Optuna study results. + :type hpopt_db: str + :param config_file_out: Path to store a reduced configuration file. 
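To make the expected `--phenotype` inputs concrete, here is a self-contained sketch of the `data` dictionary assembled above, using tiny placeholder zarr arrays written to a temporary directory (all file names below are made up for illustration):

```python
import tempfile
from pathlib import Path

import numpy as np
import torch
import zarr

tmp = Path(tempfile.mkdtemp())

# Tiny placeholder arrays standing in for the real pipeline outputs.
zarr.save_array(str(tmp / "input_tensor.zarr"), np.zeros((4, 2, 3, 5), dtype=np.float32))
zarr.save_array(str(tmp / "covariates.zarr"), np.zeros((4, 2), dtype=np.float32))
zarr.save_array(str(tmp / "y.zarr"), np.zeros((4, 1), dtype=np.float32))

phenotype = [("Urate", tmp / "input_tensor.zarr", tmp / "covariates.zarr", tmp / "y.zarr")]
samples = slice(None)

data = {}
for pheno, input_tensor_file, covariates_file, y_file in phenotype:
    data[pheno] = {
        "input_tensor_zarr": zarr.open(str(input_tensor_file), mode="r"),
        "covariates": torch.tensor(zarr.open(str(covariates_file), mode="r")[:])[samples],
        "y": torch.tensor(zarr.open(str(y_file), mode="r")[:])[samples],
    }

print(data["Urate"]["input_tensor_zarr"].shape)  # (4, 2, 3, 5)
```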
+ :type config_file_out: str + + :returns: None + """ study = optuna.load_study( study_name=Path(hpopt_db).stem, storage=f"sqlite:///{hpopt_db}" ) @@ -796,6 +1152,7 @@ def best_training_run( link_path.symlink_to(checkpoint.resolve(strict=True)) # Keep track of models marked to be dropped + # respective models are not used for downstream processing checkpoint_dropped = Path(str(checkpoint) + ".dropped") if checkpoint_dropped.is_file(): dropped_link_path = Path(checkpoint_dir) / f"bag_{k}.ckpt.dropped" diff --git a/deeprvat/metrics.py b/deeprvat/metrics.py index 429ddfb3..f7b74a01 100644 --- a/deeprvat/metrics.py +++ b/deeprvat/metrics.py @@ -15,10 +15,24 @@ class RSquared: + """ + Calculates the R-squared (coefficient of determination) between predictions and targets. + """ + def __init__(self): pass def __call__(self, preds: torch.tensor, targets: torch.tensor): + """ + Calculate R-squared value between two tensors. + + :param preds: Tensor containing predicted values. + :type preds: torch.tensor + :param targets: Tensor containing target values. + :type targets: torch.tensor + :return: R-squared value. + :rtype: torch.tensor + """ y_mean = torch.mean(targets) ss_tot = torch.sum(torch.square(targets - y_mean)) ss_res = torch.sum(torch.square(targets - preds)) @@ -26,10 +40,24 @@ def __call__(self, preds: torch.tensor, targets: torch.tensor): class PearsonCorr: + """ + Calculates the Pearson correlation coefficient between burdens and targets. + """ + def __init__(self): pass def __call__(self, burden, y): + """ + Calculate Pearson correlation coefficient. + + :param burden: Tensor containing burden values. + :type burden: torch.tensor + :param y: Tensor containing target values. + :type y: torch.tensor + :return: Pearson correlation coefficient. + :rtype: float + """ if len(burden.shape) > 1: # was the burden computed for >1 genes corrs = [] for i in range(burden.shape[1]): # number of genes @@ -48,10 +76,24 @@ def __call__(self, burden, y): class PearsonCorrTorch: + """ + Calculates the Pearson correlation coefficient between burdens and targets using PyTorch tensor operations. + """ + def __init__(self): pass def __call__(self, burden, y): + """ + Calculate Pearson correlation coefficient using PyTorch tensor operations. + + :param burden: Tensor containing burden values. + :type burden: torch.tensor + :param y: Tensor containing target values. + :type y: torch.tensor + :return: Pearson correlation coefficient. + :rtype: torch.tensor + """ if len(burden.shape) > 1: # was the burden computed for >1 genes corrs = [] for i in range(burden.shape[1]): # number of genes @@ -83,9 +125,23 @@ def calculate_pearsonr(self, x, y): class AveragePrecisionWithLogits: + """ + Calculates the average precision score between logits and targets. + """ + def __init__(self): pass def __call__(self, logits, y): + """ + Calculate average precision score. + + :param logits: Tensor containing logits. + :type logits: torch.tensor + :param y: Tensor containing target values. + :type y: torch.tensor + :return: Average precision score. 
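As a quick illustration of how these metric objects are called, here is a standalone sketch; the class body below mirrors the `RSquared` implementation documented above, and the other metrics follow the same call pattern:

```python
import torch

class RSquared:
    """Coefficient of determination between predictions and targets."""

    def __call__(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        y_mean = torch.mean(targets)
        ss_tot = torch.sum(torch.square(targets - y_mean))
        ss_res = torch.sum(torch.square(targets - preds))
        return 1 - ss_res / ss_tot

targets = torch.tensor([1.0, 2.0, 3.0, 4.0])
preds = torch.tensor([1.1, 1.9, 3.2, 3.8])
print(RSquared()(preds, targets))  # close to 1.0 for a good fit
```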
+ :rtype: float + """ y_scores = F.sigmoid(logits.detach()) return average_precision_score(y.detach().cpu().numpy(), y_scores.cpu().numpy()) diff --git a/deeprvat/preprocessing/README.md b/deeprvat/preprocessing/README.md deleted file mode 100644 index d499a6b6..00000000 --- a/deeprvat/preprocessing/README.md +++ /dev/null @@ -1,166 +0,0 @@ -# DeepRVAT Preprocessing pipeline - -The DeepRVAT preprocessing pipeline is based on [snakemake](https://snakemake.readthedocs.io/en/stable/) it uses -[bcftools+samstools](https://www.htslib.org/) and a [python script](preprocess.py) preprocessing.py. - -![DeepRVAT preprocessing pipeline](./preprocess_rulegraph.svg) - -## Output - -The important files that this pipeline produces that are needed in DeepRVAT are: - -- **preprocessed/genotypes.h5** *The main sparse hdf5 file* - -- **norm/variants/variants.parquet** *List of variants i parquet format* - -## Setup environment - -Create the DeepRVAT processing environment - -Clone this repository: - -```shell -git clone git@github.com:PMBio/deeprvat.git -``` - -Change directory to the repository: `cd deeprvat` - -```shell -mamba env create --file deeprvat_preprocessing_env.yml -``` - -Activate the environment - -```shell -mamba activate deeprvat_preprocess -``` - -Install DeepRVAT in the environment - -```shell -pip install -e . -``` - -## Configure preprocessing - -The snakemake preprocessing is configured using a yaml file with the format below. -An example file is included in this repo: [example config](config/deeprvat_preprocess_config.yaml). - -```yaml -# What chromosomes should be processed -included_chromosomes: [ 20,21,22 ] - -# If you need to run a cmd to load bcf and samtools specify it here -bcftools_load_cmd: module load bcftools/1.10.2 && -samtools_load_cmd: module load samtools/1.9 && - -# Path to where you want to write results and intermediate data -working_dir: /workdir -# Path to ukbb data -data_dir: /data - -# These paths are all relative to the data dir -input_vcf_dir_name: vcf -metadata_dir_name: metadata - -# expected to be found in the data_dir / metadata_dir -pvcf_blocks_file: pvcf_blocks.txt - -# These paths are all relative to the working dir -# Here will the finished preprocessed files end up -preprocessed_dir_name: preprocesed -# Path to directory with fasta reference file -reference_dir_name: reference -# Here we will store normalized bcf files -norm_dir_name: norm -# Here we store "sparsified" bcf files -sparse_dir_name: sparse - -# Expected to be found in working_dir/reference_dir -reference_fasta_file: GRCh38_full_analysis_set_plus_decoy_hla.fa - -# The format of the name of the "raw" vcf files -vcf_filename_pattern: ukb23156_c{chr}_b{block}_v1.vcf.gz - -# Number of threads to use in the preprocessing script, separate from snakemake threads -preprocess_threads: 16 - ``` - -The config above would use the following directory structure: - -```shell -parent_directory -|-- data -| |-- metadata -| `-- vcf -`-- workdir - |-- norm - | |-- bcf - | |-- sparse - | `-- variants - |-- preprocesed - |-- qc - | |-- allelic_imbalance - | |-- duplicate_vars - | |-- filtered_samples - | |-- hwe - | |-- indmiss - | | |-- samples - | | |-- sites - | | `-- stats - | |-- read_depth - | `-- varmiss - `-- reference - -``` - -## Running the preprocess pipeline - -### Run the preprocess pipeline with example data - -*The vcf files in the example data folder was generated using [fake-vcf](https://github.com/endast/fake-vcf) (with some -manual editing). -hence does not contain real data.* - -1. 
cd into the preprocessing example dir - -```shell -cd -cd example/preprocess -``` - -2. Download the fasta file - -```shell -wget https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_44/GRCh38.primary_assembly.genome.fa.gz -P workdir/reference -``` - -3. Unpack the fasta file - -```shell -gzip -d workdir/reference/GRCh38.primary_assembly.genome.fa.gz -``` - -4. Run with the example config - -```shell -snakemake -j 1 --snakefile ../../pipelines/preprocess.snakefile --configfile ../../pipelines/config/deeprvat_preprocess_config.yaml -``` - -5. Enjoy the preprocessed data 🎉 - -```shell -ls -l workdir/preprocesed -total 48 --rw-r--r-- 1 user staff 6404 Aug 2 14:06 genotypes.h5 --rw-r--r-- 1 user staff 6354 Aug 2 14:06 genotypes_chr21.h5 --rw-r--r-- 1 user staff 6354 Aug 2 14:06 genotypes_chr22.h5 -``` - -### Run on your own data - -After configuration and activating the environment run the pipeline using snakemake: - -```shell -snakemake -j --configfile config/deeprvat_preprocess_config.yaml -s preprocess.snakefile -``` diff --git a/deeprvat/preprocessing/__init__.py b/deeprvat/preprocessing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/deeprvat/preprocessing/preprocess.py b/deeprvat/preprocessing/preprocess.py index 729e0789..46eba386 100644 --- a/deeprvat/preprocessing/preprocess.py +++ b/deeprvat/preprocessing/preprocess.py @@ -1,6 +1,5 @@ import gc import logging -import os import sys import time from pathlib import Path @@ -49,10 +48,10 @@ def process_sparse_gt_file( samples: List[str], calls_to_exclude: pd.DataFrame = None, ) -> Tuple[List[np.ndarray], List[np.ndarray]]: - sparse_gt = pd.read_csv( + sparse_gt = pd.read_table( file, names=["chrom", "pos", "ref", "alt", "sample", "gt"], - sep="\t", + engine="pyarrow", index_col=None, ) sparse_gt = sparse_gt[sparse_gt["sample"].isin(samples)] @@ -146,6 +145,18 @@ def add_variant_ids(variant_file: str, out_file: str, duplicates_file: str): ) +def get_file_chromosome(file, col_names, chrom_field="chrom"): + chrom_df = pd.read_csv( + file, names=col_names, sep="\t", index_col=None, nrows=1, usecols=[chrom_field] + ) + + chrom = None + if not chrom_df.empty: + chrom = chrom_df[chrom_field][0] + + return chrom + + @cli.command() @click.option("--exclude-variants", type=click.Path(exists=True), multiple=True) @click.option("--exclude-samples", type=click.Path(exists=True)) @@ -171,19 +182,24 @@ def process_sparse_gt( ): logging.info("Reading variants...") start_time = time.time() - variants = pd.read_csv(variant_file, sep="\t") + + variants = pd.read_table(variant_file, engine="pyarrow") + + # Filter all variants based on chromosome if chromosomes is not None: chromosomes = [f"chr{chrom}" for chrom in chromosomes.split(",")] variants = variants[variants["chrom"].isin(chromosomes)] total_variants = len(variants) + if len(exclude_variants) > 0: variant_exclusion_files = [ - v for directory in exclude_variants for v in Path(directory).glob("*.tsv*") + v for directory in exclude_variants for v in Path(directory).rglob("*.tsv*") ] + exclusion_file_cols = ["chrom", "pos", "ref", "alt"] variants_to_exclude = pd.concat( [ - pd.read_csv(v, sep="\t", names=["chrom", "pos", "ref", "alt"]) + pd.read_csv(v, sep="\t", names=exclusion_file_cols) for v in tqdm(variant_exclusion_files) ], ignore_index=True, @@ -212,7 +228,7 @@ def process_sparse_gt( if exclude_samples is not None: total_samples = len(samples) - if sample_exclusion_files := list(Path(exclude_samples).glob("*.csv")): + if sample_exclusion_files := 
list(Path(exclude_samples).rglob("*.csv")): samples_to_exclude = set( pd.concat( [ @@ -227,8 +243,7 @@ def process_sparse_gt( else: logging.info(f"Found no samples to exclude in {exclude_samples}") - # Assumes only numeric sample names - samples = sorted([s for s in samples if int(s) > 0]) + samples = list(samples) logging.info("Processing sparse GT files by chromosome") total_calls_dropped = 0 @@ -237,39 +252,51 @@ def process_sparse_gt( for chrom in tqdm(variant_groups.groups.keys()): logging.info(f"Processing chromosome {chrom}") logging.info("Reading in filtered calls") + + exclude_calls_file_cols = ["chrom", "pos", "ref", "alt", "sample"] + + calls_to_exclude = pd.DataFrame(columns=exclude_calls_file_cols) + if exclude_calls is not None: - chrom_dir = os.path.join(exclude_calls, chrom) - exclude_calls_chrom = Path(chrom_dir).glob("*.tsv*") - - calls_to_exclude = pd.concat( - [ - pd.read_csv( - c, - names=["chrom", "pos", "ref", "alt", "sample"], - sep="\t", - index_col=None, - ) - for c in tqdm(exclude_calls_chrom) - ], - ignore_index=True, - ) + exclude_calls_chrom = Path(exclude_calls).rglob("*.tsv*") + exclude_calls_chrom_files = [] + for exclude_call_file in exclude_calls_chrom: + file_chrom = get_file_chromosome( + exclude_call_file, col_names=exclude_calls_file_cols + ) - calls_dropped = len(calls_to_exclude) - total_calls_dropped += calls_dropped - logging.info(f"Dropped {calls_dropped} calls") - else: - calls_to_exclude = pd.DataFrame( - columns=["chrom", "pos", "ref", "alt", "sample"] - ) + if file_chrom == chrom: + exclude_calls_chrom_files.append(exclude_call_file) - logging.info("Processing sparse GT files") + if exclude_calls_chrom_files: + calls_to_exclude = pd.concat( + [ + pd.read_csv( + c, + names=exclude_calls_file_cols, + sep="\t", + index_col=None, + ) + for c in tqdm(exclude_calls_chrom_files) + ], + ignore_index=True, + ) + calls_dropped = len(calls_to_exclude) + total_calls_dropped += calls_dropped + logging.info(f"Dropped {calls_dropped} calls") - chrom_dir = os.path.join(sparse_gt, chrom) - logging.info(f"chrom dir is {chrom_dir}") + logging.info("Processing sparse GT files") variants_chrom = variant_groups.get_group(chrom) - sparse_gt_chrom = list(Path(chrom_dir).glob("*.tsv*")) + sparse_file_cols = ["chrom", "pos", "ref", "alt", "sample", "gt"] + + sparse_gt_chrom = [ + f + for f in Path(sparse_gt).rglob("*.tsv*") + if get_file_chromosome(f, col_names=sparse_file_cols) == chrom + ] + logging.info(f"sparse gt chrom(es) are: {sparse_gt_chrom}") processed = Parallel(n_jobs=threads, verbose=50)( diff --git a/deeprvat/seed_gene_discovery/__init__.py b/deeprvat/seed_gene_discovery/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/deeprvat/seed_gene_discovery/config.yaml b/deeprvat/seed_gene_discovery/config.yaml index 444d0635..f4a5eba0 100644 --- a/deeprvat/seed_gene_discovery/config.yaml +++ b/deeprvat/seed_gene_discovery/config.yaml @@ -20,7 +20,20 @@ phenotypes: # - Platelet_crit # - Platelet_distribution_width # - Red_blood_cell_erythrocyte_count - +# - Body_mass_index_BMI +# - Glucose +# - Vitamin_D +# - Albumin +# - Total_protein +# - Cystatin_C +# - Gamma_glutamyltransferase +# - Alkaline_phosphatase +# - Creatinine +# - Whole_body_fat_free_mass +# - Forced_expiratory_volume_in_1_second_FEV1 +# - Glycated_haemoglobin_HbA1c +# - WHR_Body_mass_index_BMI_corrected + variant_types: - missense - plof @@ -42,7 +55,7 @@ test_config: neglect_homozygous: False collapse_method: sum #collapsing method for burde var_weight_function: beta_maf - + 
min_mac: 10 variant_file: variants.parquet data: @@ -99,3 +112,4 @@ data: num_workers: 10 #batch_size: 20 + diff --git a/deeprvat/seed_gene_discovery/seed_gene_discovery.py b/deeprvat/seed_gene_discovery/seed_gene_discovery.py index 28d3e160..bd9c6781 100644 --- a/deeprvat/seed_gene_discovery/seed_gene_discovery.py +++ b/deeprvat/seed_gene_discovery/seed_gene_discovery.py @@ -7,6 +7,8 @@ import time from pathlib import Path from typing import Any, Dict, List, Optional, Tuple +import copy + import click import numpy as np @@ -18,7 +20,7 @@ from tqdm import tqdm from deeprvat.data import DenseGTDataset -from seak.scoretest import ScoretestNoK +from seak import scoretest logging.basicConfig( format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s", @@ -38,6 +40,14 @@ def replace_in_array(arr, old_val, new_val): return np.where(arr == old_val, new_val, arr) +def get_caf(G): + # get the cumulative allele frequency + ac = G.sum(axis=0) # allele count of each variant + af = ac / (G.shape[0] * 2) # allele frequency of each variant + caf = af.sum() + return caf + + # return mask def save_burdens(GW_list, GW_full_list, split, chunk, out_dir): burdens_path = Path(f"{out_dir}/burdens") @@ -178,7 +188,11 @@ def get_anno( def call_score(GV, null_model_score, pval_dict, test_type): # score test # p-value for the score-test + start_time = time.time() pv = null_model_score.pv_alt_model(GV) + end_time = time.time() + time_diff = end_time - start_time + pval_dict["time"] = time_diff logger.info(f"p-value: {pv}") if pv < 0.0: logger.warning( @@ -195,10 +209,15 @@ def call_score(GV, null_model_score, pval_dict, test_type): if pv < 1e-3 and test_type == "burden": logger.info("Computing regression coefficient") # if gene is quite significant get the regression coefficient + SE - beta = null_model_score.coef(GV) - logger.info(f"Regression coefficient: {beta}") - pval_dict["beta"] = beta["beta"][0, 0] - pval_dict["betaSd"] = np.sqrt(beta["var_beta"][0, 0]) + # only works for quantitative traits + try: + beta = null_model_score.coef(GV) + logger.info(f"Regression coefficient: {beta}") + pval_dict["beta"] = beta["beta"][0, 0] + pval_dict["betaSd"] = np.sqrt(beta["var_beta"][0, 0]) + except: + pval_dict["beta"] = None + pval_dict["betaSd"] = None return pval_dict @@ -207,13 +226,14 @@ def test_gene( G_full: spmatrix, gene: int, grouped_annotations: pd.DataFrame, - dataset: DenseGTDataset, + Y, weight_cols: List[str], - null_model_score: ScoretestNoK, + null_model_score: scoretest.ScoretestNoK, test_config: Dict, var_type, test_type, maf_col, + min_mac, ) -> Dict[str, Any]: # Find variants present in gene # Convert sparse genotype to CSC @@ -232,11 +252,16 @@ def test_gene( # GET expected allele count (EAC) as in Karczewski et al. 
2022/Genebass vars_per_sample = np.sum(G, axis=1) samples_with_variant = vars_per_sample[vars_per_sample > 0].shape[0] - EAC = np.sum(vars_per_sample) + if len(np.unique(Y)) == 2: + n_cases = (Y > 0).sum() + else: + n_cases = Y.shape[0] + EAC = get_caf(G) * n_cases pval_dict = {} pval_dict["EAC"] = EAC + pval_dict["n_cases"] = n_cases pval_dict["gene"] = gene pval_dict["pval"] = np.nan pval_dict["EAC_filtered"] = np.nan @@ -247,11 +272,11 @@ def test_gene( pval_dict["time"] = np.nan var_weight_function = test_config.get("var_weight_function", "sift_polyphen") + max_n_markers = test_config.get("max_n_markers", 5000) + # skips genes with more than max_n_markers qualifying variants logger.info(f"Using function {var_weight_function} for variant weighting") - # keep backwards compatibility - ( weights, _, @@ -272,12 +297,12 @@ def test_gene( f"Number of variants after thresholding using threshold {variant_weight_th}: {len(pos)}" ) pval_dict["n_QV"] = len(pos) - - if len(pos) > 0: + pval_dict["markers_after_mac_collapsing"] = len(pos) + if (len(pos) > 0) & (len(pos) < max_n_markers): G_f = G[:, pos] - EAC_filtered = np.sum(np.sum(G_f, axis=1)) + EAC_filtered = EAC = get_caf(G_f) * n_cases pval_dict["EAC_filtered"] = EAC_filtered - + MAC = G_f.sum(axis=0) count = G_f[G_f == 2].shape[0] # confirm that variants we include are rare variants @@ -303,11 +328,28 @@ def test_gene( pval_dict["n_cluster"] = GW.shape[1] ### COLLAPSE kernel if doing burden test - + collapse_ultra_rare = True if test_type == "skat": logger.info("Running Skat test") - GW = GW - + if collapse_ultra_rare: + logger.info(f"Max Collapsing variants with MAC <= {min_mac}") + MAC_mask = MAC <= min_mac + if MAC_mask.sum() > 0: + logger.info(f"Number of collapsed positions: {MAC_mask.sum()}") + GW_collapse = copy.deepcopy(GW) + GW_collapse = GW_collapse[:, MAC_mask].max(axis=1).reshape(-1, 1) + GW = GW[:, ~MAC_mask] + GW = np.hstack((GW_collapse, GW)) + logger.info(f"GW shape {GW.shape}") + else: + logger.info( + f"No ultra rare variants to collapse ({MAC_mask.sum()})" + ) + GW = GW + else: + GW = GW + + pval_dict["markers_after_mac_collapsing"] = GW.shape[1] if test_type == "burden": collapse_method = test_config.get("collapse_method", "binary") logger.info(f"Running burden test with collapsing method {collapse_method}") @@ -335,14 +377,18 @@ def run_association_( ) -> pd.DataFrame: # initialize the null models # ScoretestNoK automatically adds a bias column if not present - null_model_score = ScoretestNoK(Y, X) + if len(np.unique(Y)) == 2: + print("Fitting binary model since only found two distinct y values") + null_model_score = scoretest.ScoretestLogit(Y, X) + else: + null_model_score = scoretest.ScoretestNoK(Y, X) stats = [] GW_list = {} GW_full_list = {} time_list_inner = {} weight_cols = config.get("weight_cols", []) logger.info(f"Testing with this config: {config['test_config']}") - + min_mac = config["test_config"].get("min_mac", 0) # Get column with minor allele frequency annotations = config["data"]["dataset_config"]["annotations"] maf_col = [ @@ -360,13 +406,14 @@ def run_association_( G_full, gene, grouped_annotations, - dataset, + Y, weight_cols, null_model_score, config["test_config"], var_type, test_type, maf_col, + min_mac, ) if persist_burdens: GW_list[gene] = GW @@ -421,7 +468,7 @@ def update_config( simulated_phenotype_file: str, variant_type: Optional[str], rare_maf: Optional[float], - maf_column: Optional[str], + maf_column: str, new_config_file: str, ): with open(old_config_file) as f: @@ -431,6 +478,7 @@ def 
update_config( config["data"]["dataset_config"][ "sim_phenotype_file" ] = simulated_phenotype_file + logger.info(f"Reading MAF column from column {maf_column}") if phenotype is not None: config["data"]["dataset_config"]["y_phenotypes"] = [phenotype] @@ -645,9 +693,10 @@ def run_association( exploded_annotations = ( dataset.annotation_df.query("id in @all_variants") .explode("gene_ids") + .reset_index() .drop_duplicates() - ) # row can be duplicated if a variant is assigned to a gene multiple times - + .set_index("id") + ) grouped_annotations = exploded_annotations.groupby("gene_ids") gene_ids = pd.read_parquet(dataset.gene_file, columns=["id"])["id"].to_list() gene_ids = list( diff --git a/deeprvat/utils.py b/deeprvat/utils.py index ef9bdf99..3ecad145 100644 --- a/deeprvat/utils.py +++ b/deeprvat/utils.py @@ -24,6 +24,16 @@ def fdrcorrect_df(group: pd.DataFrame, alpha: float) -> pd.DataFrame: + """ + Apply False Discovery Rate (FDR) correction to p-values in a DataFrame. + + :param group: DataFrame containing a "pval" column. + :type group: pd.DataFrame + :param alpha: Significance level. + :type alpha: float + :return: Original DataFrame with additional columns "significant" and "pval_corrected". + :rtype: pd.DataFrame + """ group = group.copy() rejected, pval_corrected = fdrcorrection(group["pval"], alpha=alpha) @@ -33,6 +43,16 @@ def fdrcorrect_df(group: pd.DataFrame, alpha: float) -> pd.DataFrame: def bfcorrect_df(group: pd.DataFrame, alpha: float) -> pd.DataFrame: + """ + Apply Bonferroni correction to p-values in a DataFrame. + + :param group: DataFrame containing a "pval" column. + :type group: pd.DataFrame + :param alpha: Significance level. + :type alpha: float + :return: Original DataFrame with additional columns "significant" and "pval_corrected". + :rtype: pd.DataFrame + """ group = group.copy() pval_corrected = group["pval"] * len(group) @@ -42,6 +62,18 @@ def bfcorrect_df(group: pd.DataFrame, alpha: float) -> pd.DataFrame: def pval_correction(group: pd.DataFrame, alpha: float, correction_type: str = "FDR"): + """ + Apply p-value correction to a DataFrame. + + :param group: DataFrame containing a column named "pval" with p-values to correct. + :type group: pd.DataFrame + :param alpha: Significance level. + :type alpha: float + :param correction_type: Type of p-value correction. Options are 'FDR' (default) and 'Bonferroni'. + :type correction_type: str + :return: Original DataFrame with additional columns "significant" and "pval_corrected". + :rtype: pd.DataFrame + """ if correction_type == "FDR": corrected = fdrcorrect_df(group, alpha) elif correction_type == "Bonferroni": @@ -56,7 +88,21 @@ def pval_correction(group: pd.DataFrame, alpha: float, correction_type: str = "F return corrected -def suggest_hparams(config: Dict, trial: optuna.trial.Trial, basename: str = ""): +def suggest_hparams( + config: Dict, trial: optuna.trial.Trial, basename: str = "" +) -> Dict: + """ + Suggest hyperparameters using Optuna's suggest methods. + + :param config: Configuration dictionary with hyperparameter specifications. + :type config: Dict + :param trial: Optuna trial instance. + :type trial: optuna.trial.Trial + :param basename: Base name for hyperparameter suggestions. + :type basename: str + :return: Updated configuration with suggested hyperparameters. 
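The `hparam` blocks consumed by `suggest_hparams` (see the pattern sketched in `run_bagging` above) map onto Optuna's `trial.suggest_*` methods. A rough, self-contained sketch of that translation follows; `suggest_from_spec` is a hypothetical helper written only for illustration, not part of the codebase:

```python
import optuna

# Hypothetical spec in the style used in the training config:
#   phi_hidden_dim:
#     hparam:
#       type: int
#       args: [16, 64]
#       kwargs: {step: 16}
hparam_spec = {"type": "int", "args": [16, 64], "kwargs": {"step": 16}}

def suggest_from_spec(name: str, spec: dict, trial: optuna.trial.Trial):
    # Dispatch to the matching trial.suggest_* method, e.g. trial.suggest_int.
    suggest_fn = getattr(trial, f"suggest_{spec['type']}")
    return suggest_fn(name, *spec.get("args", []), **spec.get("kwargs", {}))

def objective(trial: optuna.trial.Trial) -> float:
    phi_hidden_dim = suggest_from_spec("phi_hidden_dim", hparam_spec, trial)
    return float(phi_hidden_dim)  # dummy objective, only to make the study runnable

study = optuna.create_study()
study.optimize(objective, n_trials=3)
print(study.best_params)
```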
+ :rtype: Dict + """ config = copy.deepcopy(config) for k, cfg in config.items(): if isinstance(cfg, dict): @@ -75,6 +121,14 @@ def suggest_hparams(config: Dict, trial: optuna.trial.Trial, basename: str = "") def compute_se(errors: np.ndarray) -> float: + """ + Compute standard error. + + :param errors: Array of errors. + :type errors: np.ndarray + :return: Standard error. + :rtype: float + """ mean_error = np.mean(errors) n = errors.shape[0] error_variance = np.mean((errors - mean_error) ** 2) / (n - 1) * n @@ -82,6 +136,14 @@ def compute_se(errors: np.ndarray) -> float: def standardize_series(x: pd.Series) -> pd.Series: + """ + Standardize a pandas Series. + + :param x: Input Series. + :type x: pd.Series + :return: Standardized Series. + :rtype: pd.Series + """ x = x.astype(np.float32) mean = x.mean() variance = ((x - mean) ** 2).mean() @@ -91,7 +153,17 @@ def standardize_series(x: pd.Series) -> pd.Series: def my_quantile_transform(x, seed=1): """ - returns Gaussian quantile transformed values, "nan" are kept + Gaussian quantile transform for values in a pandas Series. + + :param x: Input pandas Series. + :type x: pd.Series + :param seed: Random seed. + :type seed: int + :return: Transformed Series. + :rtype: pd.Series + + .. note:: + "nan" values are kept """ np.random.seed(seed) x_transform = x.copy().to_numpy() @@ -110,11 +182,31 @@ def my_quantile_transform(x, seed=1): def standardize_series_with_params(x: pd.Series, std, mean) -> pd.Series: + """ + Standardize a pandas Series using provided standard deviation and mean. + + :param x: Input Series. + :type x: pd.Series + :param std: Standard deviation to use for standardization. + :param mean: Mean to use for standardization. + :return: Standardized Series. + :rtype: pd.Series + """ x = x.apply(lambda x: (x - mean) / std if x != 0 else 0) return x def calculate_mean_std(x: pd.Series, ignore_zero=True) -> pd.Series: + """ + Calculate mean and standard deviation of a pandas Series. + + :param x: Input Series. + :type x: pd.Series + :param ignore_zero: Whether to ignore zero values in calculations, defaults to True. + :type ignore_zero: bool + :return: Tuple of standard deviation and mean. + :rtype: Tuple[float, float] + """ x = x.astype(np.float32) if ignore_zero: x = x[x != float(0)] @@ -130,6 +222,22 @@ def safe_merge( validate: str = "1:1", equal_row_nums: bool = False, ): + """ + Safely merge two pandas DataFrames. + + :param left: Left DataFrame. + :type left: pd.DataFrame + :param right: Right DataFrame. + :type right: pd.DataFrame + :param validate: Validation method for the merge. + :type validate: str + :param equal_row_nums: Whether to check if the row numbers are equal, defaults to False. + :type equal_row_nums: bool + :raises ValueError: If left and right dataframe rows are unequal when 'equal_row_nums' is True. + :raises RuntimeError: If merged DataFrame has unequal row numbers compared to the left DataFrame. + :return: Merged DataFrame. + :rtype: pd.DataFrame + """ if equal_row_nums: try: assert len(left) == len(right) @@ -153,6 +261,14 @@ def safe_merge( def resolve_path_with_env(path: str) -> str: + """ + Resolve a path with environment variables. + + :param path: Input path. + :type path: str + :return: Resolved path. + :rtype: str + """ path_split = [] head = path while head not in ("", "/"): @@ -168,6 +284,16 @@ def resolve_path_with_env(path: str) -> str: def copy_with_env(path: str, destination: str) -> str: + """ + Copy a file or directory to a destination with environment variables. 
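For example, `standardize_series` above is meant to return a zero-mean, unit-variance float32 series; a small sketch of the expected behaviour (the body mirrors the lines shown above, with the final division added here for illustration):

```python
import numpy as np
import pandas as pd

def standardize_series(x: pd.Series) -> pd.Series:
    # Mirrors the helper documented above.
    x = x.astype(np.float32)
    mean = x.mean()
    variance = ((x - mean) ** 2).mean()
    return (x - mean) / np.sqrt(variance)

s = pd.Series([1.0, 2.0, 3.0, 4.0])
z = standardize_series(s)
print(round(float(z.mean()), 6), round(float(z.var(ddof=0)), 6))  # ~0.0 and ~1.0
```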
+ + :param path: Input path (file or directory). + :type path: str + :param destination: Destination path. + :type destination: str + :return: Resulting destination path. + :rtype: str + """ destination = resolve_path_with_env(destination) if os.path.isfile(path): @@ -191,6 +317,16 @@ def copy_with_env(path: str, destination: str) -> str: def load_or_init(pickle_file: str, init_fn: Callable) -> Any: + """ + Load a pickled file or initialize an object. + + :param pickle_file: Pickle file path. + :type pickle_file: str + :param init_fn: Initialization function. + :type init_fn: Callable + :return: Loaded or initialized object. + :rtype: Any + """ if pickle_file is not None and os.path.isfile(pickle_file): logger.info(f"Using pickled file {pickle_file}") with open(pickle_file, "rb") as f: @@ -204,6 +340,16 @@ def load_or_init(pickle_file: str, init_fn: Callable) -> Any: def remove_prefix(string, prefix): + """ + Remove a prefix from a string. + + :param string: Input string. + :type string: str + :param prefix: Prefix to remove. + :type prefix: str + :return: String without the specified prefix. + :rtype: str + """ if string.startswith(prefix): return string[len(prefix) :] return string @@ -218,6 +364,18 @@ def suggest_batch_size( }, buffer_bytes: int = 2_500_000_000, ): + """ + Suggest a batch size for a tensor based on available GPU memory. + + :param tensor_shape: Shape of the tensor. + :type tensor_shape: Iterable[int] + :param example: Example dictionary with batch size, tensor shape, and max memory bytes. + :type example: Dict[str, Any] + :param buffer_bytes: Buffer bytes to consider. + :type buffer_bytes: int + :return: Suggested batch size for the given tensor shape and GPU memory. + :rtype: int + """ gpu_mem_bytes = torch.cuda.get_device_properties(0).total_memory batch_size = math.floor( example["batch_size"] diff --git a/deeprvat_env.yaml b/deeprvat_env.yaml index fa067c26..778713fe 100644 --- a/deeprvat_env.yaml +++ b/deeprvat_env.yaml @@ -30,6 +30,7 @@ dependencies: - tqdm=4.59 - zarr=2.13 - Cython=0.29 + - parallel=20230922 - pip=23.0.1 - plotnine=0.10.1 - pip: diff --git a/deeprvat_env_no_gpu.yml b/deeprvat_env_no_gpu.yml index 1560df75..b98d1cfb 100644 --- a/deeprvat_env_no_gpu.yml +++ b/deeprvat_env_no_gpu.yml @@ -27,6 +27,7 @@ dependencies: - tqdm=4.59 - zarr=2.13 - Cython=0.29 + - parallel=20230922 - pip=23.1.2 - plotnine=0.10.1 - pip: diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 00000000..d4bb2cbb --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
+%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/annotation_pipeline_dag.png b/docs/_static/annotation_pipeline_dag.png new file mode 100644 index 00000000..0a088178 Binary files /dev/null and b/docs/_static/annotation_pipeline_dag.png differ diff --git a/docs/_static/preprocess_rulegraph_no_qc.svg b/docs/_static/preprocess_rulegraph_no_qc.svg new file mode 100644 index 00000000..3edc0a0e --- /dev/null +++ b/docs/_static/preprocess_rulegraph_no_qc.svg @@ -0,0 +1,169 @@ + + + + + + +snakemake_dag + + + +0 + +all + + + +1 + +combine_genotypes + + + +1->0 + + + + + +2 + +preprocess_no_qc + + + +2->1 + + + + + +3 + +add_variant_ids + + + +3->0 + + + + + +3->2 + + + + + +4 + +concatenate_variants + + + +4->3 + + + + + +9 + +create_parquet_variant_ids + + + +4->9 + + + + + +5 + +variants + + + +5->4 + + + + + +6 + +normalize + + + +6->5 + + + + + +10 + +sparsify + + + +6->10 + + + + + +7 + +extract_samples + + + +7->2 + + + + + +7->6 + + + + + +8 + +index_fasta + + + +8->6 + + + + + +9->0 + + + + + +9->2 + + + + + +10->2 + + + + + diff --git a/deeprvat/preprocessing/preprocess_rulegraph.svg b/docs/_static/preprocess_rulegraph_with_qc.svg similarity index 77% rename from deeprvat/preprocessing/preprocess_rulegraph.svg rename to docs/_static/preprocess_rulegraph_with_qc.svg index 22930d7a..5c3d86e5 100644 --- a/deeprvat/preprocessing/preprocess_rulegraph.svg +++ b/docs/_static/preprocess_rulegraph_with_qc.svg @@ -1,7 +1,7 @@ - 0 - + all 1 - + combine_genotypes - + 1->0 - + 2 - -preprocess + +preprocess_with_qc 2->1 - + 3 - + add_variant_ids - + 3->0 - + - + 3->2 - - + + 4 - + concatenate_variants 4->3 - + 9 - + create_parquet_variant_ids 4->9 - + 5 - + variants 5->4 - + 6 - + normalize 6->5 - + 10 - + sparsify 6->10 - + 11 - + qc_varmiss 6->11 - + 12 - + qc_hwe 6->12 - + 13 - + qc_read_depth 6->13 - + 14 - + qc_allelic_imbalance 6->14 - + 7 - + extract_samples - + 7->2 - - + + 7->6 - + 8 - + index_fasta 8->6 - + - + 9->0 - + - + 9->2 - - + + - + 10->2 - + - + 11->2 - + - + 12->2 - + - + 13->2 - + - + 14->2 - + 15 - + create_excluded_samples_dir - + 15->2 - - + + diff --git a/deeprvat/annotations/README.md b/docs/annotations.md similarity index 80% rename from deeprvat/annotations/README.md rename to docs/annotations.md index f1578798..a9532220 100644 --- a/deeprvat/annotations/README.md +++ b/docs/annotations.md @@ -1,15 +1,17 @@ # DeepRVAT Annotation pipeline -This pipeline is based on [snakemake](https://snakemake.readthedocs.io/en/stable/). It uses [bcftools + samstools](https://www.htslib.org/), as well as [perl](https://www.perl.org/), [deepRiPe](https://ohlerlab.mdc-berlin.de/software/DeepRiPe_140/) and [deepSEA](http://deepsea.princeton.edu/) as well as [VEP](http://www.ensembl.org/info/docs/tools/vep/index.html), including plugins for [primateAI](https://github.com/Illumina/PrimateAI) and [spliceAI](https://github.com/Illumina/SpliceAI). DeepRiPe annotations were acquired using [faatpipe repository by HealthML](https://github.com/HealthML/faatpipe)[[1]](#1) and DeepSea annotations were calculated using [kipoi-veff2](https://github.com/kipoi/kipoi-veff2)[[2]](#2), abSplice scores were computet using [abSplice](https://github.com/gagneurlab/absplice/)[[3]](#3) +This pipeline is based on [snakemake](https://snakemake.readthedocs.io/en/stable/). 
It uses [bcftools + samstools](https://www.htslib.org/), as well as [perl](https://www.perl.org/), [deepRiPe](https://ohlerlab.mdc-berlin.de/software/DeepRiPe_140/) and [deepSEA](http://deepsea.princeton.edu/) as well as [VEP](http://www.ensembl.org/info/docs/tools/vep/index.html), including plugins for [primateAI](https://github.com/Illumina/PrimateAI) and [spliceAI](https://github.com/Illumina/SpliceAI). DeepRiPe annotations were acquired using [faatpipe repository by HealthML](https://github.com/HealthML/faatpipe)[[1]](#reference-1-target) and DeepSea annotations were calculated using [kipoi-veff2](https://github.com/kipoi/kipoi-veff2)[[2]](#reference-2-target), abSplice scores were computet using [abSplice](https://github.com/gagneurlab/absplice/)[[3]](#reference-3-target) -![dag](https://github.com/PMBio/deeprvat/assets/23211603/d483831e-3558-4e21-9845-4b62ad4eecc3) +![dag](_static/annotation_pipeline_dag.png) *Figure 1: Example DAG of annoation pipeline using only two bcf files as input.* ## Input -The pipeline uses left-normalized bcf files containing variant information, a reference fasta file as well as a text file that maps data blocks to chromosomes as input. It is expected that the bcf files contain the columns "CHROM" "POS" "ID" "REF" and "ALT". Any other columns, including genotype information are stripped from the data before annotation tools are used on the data. The variants may be split into several vcf files for each chromosome and each "block" of data. The filenames should then contain the corresponding chromosome and block number. The pattern of the file names, as well as file structure may be specified in the corresponding [config file](config/deeprvat_annotation_config.yaml). +The pipeline uses left-normalized bcf files containing variant information, a reference fasta file as well as a text file that maps data blocks to chromosomes as input. It is expected that the bcf files contain the columns "CHROM" "POS" "ID" "REF" and "ALT". Any other columns, including genotype information are stripped from the data before annotation tools are used on the data. The variants may be split into several vcf files for each chromosome and each "block" of data. The filenames should then contain the corresponding chromosome and block number. The pattern of the file names, as well as file structure may be specified in the corresponding [config file](https://github.com/PMBio/deeprvat/blob/main/pipelines/config/deeprvat_annotation_config.yaml). +(requirements-target)= ## Requirements + BCFtools as well as HTSlib should be installed on the machine, - [CADD](https://github.com/kircherlab/CADD-scripts/tree/master/src/scripts) as well as - [VEP](http://www.ensembl.org/info/docs/tools/vep/script/vep_download.html), @@ -18,9 +20,9 @@ BCFtools as well as HTSlib should be installed on the machine, - [faatpipe](https://github.com/HealthML/faatpipe), and the - [vep-plugins repository](https://github.com/Ensembl/VEP_plugins/) -will be installed by the pipeline together with the [plugins](https://www.ensembl.org/info/docs/tools/vep/script/vep_plugins.html) for primateAI and spliceAI. Annotation data for CADD, spliceAI and primateAI should be downloaded. The path to the data may be specified in the corresponding [config file](config/deeprvat_annotation_config.yaml). +will be installed by the pipeline together with the [plugins](https://www.ensembl.org/info/docs/tools/vep/script/vep_plugins.html) for primateAI and spliceAI. Annotation data for CADD, spliceAI and primateAI should be downloaded. 
The path to the data may be specified in the corresponding [config file](https://github.com/PMBio/deeprvat/blob/main/pipelines/config/deeprvat_annotation_config.yaml). Download path: -- [CADD](http://cadd.gs.washington.edu/download): "All possible SNVs of GRCh38/hg38" and "gnomad.genomes.r3.0.indel.tsv.gz" incl. their Tabix Indices +- [CADD](https://cadd.bihealth.org/download): "All possible SNVs of GRCh38/hg38" and "gnomad.genomes.r3.0.indel.tsv.gz" incl. their Tabix Indices - [SpliceAI](https://basespace.illumina.com/s/otSPW8hnhaZR): "genome_scores_v1.3"/"spliceai_scores.raw.snv.hg38.vcf.gz" and "spliceai_scores.raw.indel.hg38.vcf.gz" - [PrimateAI](https://basespace.illumina.com/s/yYGFdGih1rXL) PrimateAI supplementary data/"PrimateAI_scores_v0.2_GRCh38_sorted.tsv.bgz" @@ -30,7 +32,7 @@ Download path: The pipeline outputs one annotation file for VEP, CADD, DeepRiPe, DeepSea and Absplice for each input vcf-file. The tool further creates concatenated files for each tool and one merged file containing Scores from AbSplice, VEP incl. CADD, primateAI and spliceAI as well as principal components from DeepSea and DeepRiPe. ## Configure the annotation pipeline -The snakemake annotation pipeline is configured using a yaml file with the format akin to the [example file](config/deeprvat_annotation_config.yaml). +The snakemake annotation pipeline is configured using a yaml file with the format akin to the [example file](https://github.com/PMBio/deeprvat/blob/main/pipelines/config/deeprvat_annotation_config.yaml). The config above would use the following directory structure: ```shell @@ -73,7 +75,7 @@ The config above would use the following directory structure: ``` -Bcf files created by the [preprocessing pipeline](https://github.com/PMBio/deeprvat/blob/Annotations/deeprvat/preprocessing/README.md) are used as input data. +Bcf files created by the [preprocessing pipeline](preprocessing.md) are used as input data. The pipeline also uses the variant.tsv file as well as the reference file from the preprocesing pipeline. The pipeline beginns by installing the repositories needed for the annotations, it will automatically install all repositories in the `repo_dir` folder that can be specified in the config file relative to the annotation working directory. The text file mapping blocks to chromosomes is stored in `metadata` folder. The output is stored in the `output_dir/annotations` folder and any temporary files in the `tmp` subfolder. All repositories used including VEP with its corresponding cache as well as plugins are stored in `repo_dir/ensempl-vep`. @@ -81,27 +83,27 @@ Data for VEP plugins and the CADD cache are stored in `annotation data`. 
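+A minimal sketch of how such an annotation working directory could be laid out before the first run; the directory names below simply follow the description above and the example config, so they are placeholders rather than requirements and should match whatever your own config file specifies:
+
+```shell
+# create the skeleton of the annotation working directory (example names only)
+mkdir -p repo_dir annotation_data metadata output_dir/annotations tmp
+```
+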
## Running the annotation pipeline ### Preconfiguration -- Inside the annotation directory create a directory `repo_dir` and run the [annotation setup script](setup_annotation_workflow.sh) +- Inside the annotation directory create a directory `repo_dir` and run the [annotation setup script](https://github.com/PMBio/deeprvat/blob/main/deeprvat/annotations/setup_annotation_workflow.sh) ```shell setup_annotation_workflow.sh repo_dir/ensembl-vep/cache repo_dir/ensembl-vep/Plugins repo_dir ``` - or manually clone the repositories mentioned in the [requirements](#requirements) into `repo_dir` and install the needed conda environments with + or manually clone the repositories mentioned in the [requirements](#requirements-target) into `repo_dir` and install the needed conda environments with ```shell mamba env create -f repo_dir/absplice/environment.yaml mamba env create -f repo_dir/kipoi-veff2/environment.minimal.linux.yml mamba env create -f deeprvat/deeprvat_annotations.yml ``` - If you already have some of the needed repositories on your machine you can edit the paths in the [config](../../pipelines/config/deeprvat_annotation_config.yaml). + If you already have some of the needed repositories on your machine you can edit the paths in the [config](https://github.com/PMBio/deeprvat/blob/main/pipelines/config/deeprvat_annotation_config.yaml). -- Inside the annotation directory create a directory `annotation_dir` and download/link the prescored files for CADD, SpliceAI, and PrimateAI (see [requirements](#requirements)) +- Inside the annotation directory create a directory `annotation_dir` and download/link the prescored files for CADD, SpliceAI, and PrimateAI (see [requirements](#requirements-target)) ### Running the pipeline After configuration and activating the `deeprvat_annotations` environment run the pipeline using snakemake: ```shell - snakemake -j -s annotations.snakemake --configfile config/deeprvat_annotation.config + snakemake -j -s annotations.snakemake --configfile config/deeprvat_annotation.config --use-conda ``` ## Running the annotation pipeline without the preprocessing pipeline @@ -113,8 +115,13 @@ However, the annotation pipeline requires some files from this pipeline that the ## References + +(reference-1-target)= [1] Monti, R., Rautenstrauch, P., Ghanbari, M. et al. Identifying interpretable gene-biomarker associations with functionally informed kernel-based tests in 190,000 exomes. Nat Commun 13, 5332 (2022). https://doi.org/10.1038/s41467-022-32864-2 +(reference-2-target)= [2] Žiga Avsec et al., “Kipoi: accelerating the community exchange and reuse of predictive models for genomics,” bioRxiv, p. 375345, Jan. 2018, doi: 10.1101/375345. +(reference-3-target)= [3]N. Wagner et al., “Aberrant splicing prediction across human tissues,” Nature Genetics, vol. 55, no. 5, pp. 861–870, May 2023, doi: 10.1038/s41588-023-01373-3. + diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 00000000..f96b483f --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,33 @@ +# Configuration file for the Sphinx documentation builder. 
+# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +from datetime import datetime + +project = "DeepRVAT" +copyright = f"{datetime.now().year}, Clarke, B., Holtkamp, E., Öztürk, H., Mück, M., Wahlberg, M., Meyer, K., Brechtmann, F., Hölzlwimmer, F. R., Gagneur, J., & Stegle, O" +author = "Clarke, B., Holtkamp, E., Öztürk, H., Mück, M., Wahlberg, M., Meyer, K., Brechtmann, F., Hölzlwimmer, F. R., Gagneur, J., & Stegle, O" +version = "0.1.0" +release = version + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = ["autodoc2", "myst_parser", "sphinx_copybutton"] +autodoc2_packages = [ + "../deeprvat", +] + +templates_path = ["_templates"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = "sphinx_rtd_theme" +html_static_path = ["_static"] diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 00000000..fcabaffc --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,28 @@ +.. DeepRVAT documentation master file, created by + sphinx-quickstart on Wed Nov 22 10:24:36 2023. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to DeepRVAT's documentation! +==================================== + +Rare variant association testing using deep learning and data-driven burden scores + + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + usage.md + preprocessing.md + annotations.md + seed_gene_discovery.md + apidocs/index + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 00000000..32bb2452 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/preprocessing.md b/docs/preprocessing.md new file mode 100644 index 00000000..502b1e62 --- /dev/null +++ b/docs/preprocessing.md @@ -0,0 +1,232 @@ +# DeepRVAT Preprocessing pipeline + +The DeepRVAT preprocessing pipeline is based on [snakemake](https://snakemake.readthedocs.io/en/stable/) it uses +[bcftools+samstools](https://www.htslib.org/) and a [python script](https://github.com/PMBio/deeprvat/blob/main/deeprvat/preprocessing/preprocess.py) preprocessing.py. 
+
+![DeepRVAT preprocessing pipeline](_static/preprocess_rulegraph_no_qc.svg)
+
+## Output
+
+The important files produced by this pipeline that are needed by DeepRVAT are:
+
+- **preprocessed/genotypes.h5** *The main sparse hdf5 file*
+
+- **norm/variants/variants.parquet** *List of variants in parquet format*
+
+## Setup environment
+
+Create the DeepRVAT preprocessing environment:
+
+Clone this repository:
+
+```shell
+git clone git@github.com:PMBio/deeprvat.git
+```
+
+Change directory to the repository: `cd deeprvat`
+
+```shell
+mamba env create --file deeprvat_preprocessing_env.yml
+```
+
+Activate the environment
+
+```shell
+mamba activate deeprvat_preprocess
+```
+
+Install DeepRVAT in the environment
+
+```shell
+pip install -e .
+```
+
+## Configure preprocessing
+
+The snakemake preprocessing pipeline is configured using a yaml file with the format below.
+An example file is included in this repo: [example config](https://github.com/PMBio/deeprvat/blob/main/pipelines/config/deeprvat_preprocess_config.yaml).
+
+```yaml
+# What chromosomes should be processed
+included_chromosomes : [21,22]
+
+# Text file listing the "raw" vcf files to process
+vcf_files_list: vcf_files_list.txt
+
+# Number of threads to use in the preprocessing script, separate from snakemake threads
+preprocess_threads: 16
+
+# If you need to run a cmd to load bcf and samtools specify it here, see example
+bcftools_load_cmd : # module load bcftools/1.10.2 &&
+samtools_load_cmd : # module load samtools/1.9 &&
+
+# Path to where you want to write results and intermediate data
+working_dir: workdir
+
+# These paths are all relative to the working dir
+# The finished preprocessed files will end up here
+preprocessed_dir_name : preprocesed
+# Path to directory with fasta reference file
+reference_dir_name : reference
+# Here we will store normalized bcf files
+norm_dir_name : norm
+# Here we store "sparsified" bcf files
+sparse_dir_name : sparse
+
+# Expected to be found in working_dir/reference_dir
+reference_fasta_file : GRCh38.primary_assembly.genome.fa
+
+# You can specify a different zcat cmd for example gzcat here, default zcat
+zcat_cmd:
+ ```
+
+The config above would use the following directory structure:
+
+```shell
+parent_directory
+`-- workdir
+    |-- norm
+    |   |-- bcf
+    |   |-- sparse
+    |   `-- variants
+    |-- preprocesed
+    |-- qc
+    |   |-- allelic_imbalance
+    |   |-- duplicate_vars
+    |   |-- filtered_samples
+    |   |-- hwe
+    |   |-- indmiss
+    |   |   |-- samples
+    |   |   |-- sites
+    |   |   `-- stats
+    |   |-- read_depth
+    |   `-- varmiss
+    `-- reference
+
+```
+
+### vcf_files_list
+The `vcf_files_list` variable specifies the path to a text file that contains paths to the raw vcf files you want to
+process.
+
+ex:
+
+
+```text
+data/vcf/test_vcf_data_c21_b1.vcf.gz
+data/vcf/test_vcf_data_c22_b1.vcf.gz
+```
+
+The easiest way to create `vcf_files_list` (if you have your files in `data/vcf` under the `parent_directory`) is:
+```shell
+cd
+find data/vcf -type f -name "*.vcf*" > vcf_files_list.txt
+```
+## Running the preprocess pipeline
+
+There are two versions of the pipeline, one with qc (quality control) and one without; the version with qc is the one
+we used when we wrote the paper. The qc is specific to the UKBB data, so if you want or need to do your own qc, use the
+pipeline without qc.
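+
+Whichever variant you choose, a dry run is a cheap way to check that the config and input files are picked up correctly before committing compute. A minimal sketch, using the same paths as the example below (adjust them for your own layout); `-n` only prints the jobs snakemake would execute:
+
+```shell
+# dry run: show the planned jobs without executing them
+snakemake -n -j 1 --snakefile ../../pipelines/preprocess_no_qc.snakefile \
+    --configfile ../../pipelines/config/deeprvat_preprocess_config.yaml
+```
+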
+
+### Run the preprocess pipeline with example data and qc
+![DeepRVAT preprocessing pipeline](_static/preprocess_rulegraph_with_qc.svg)
+
+*The vcf files in the example data folder were generated using [fake-vcf](https://github.com/endast/fake-vcf) (with some
+manual editing) and hence do not contain real data.*
+
+1. cd into the preprocessing example dir
+
+```shell
+cd
+cd example/preprocess
+```
+
+2. Download the fasta file
+
+```shell
+wget https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_44/GRCh38.primary_assembly.genome.fa.gz -P workdir/reference
+```
+
+3. Unpack the fasta file
+
+```shell
+gzip -d workdir/reference/GRCh38.primary_assembly.genome.fa.gz
+```
+
+4. Run with the example config
+
+```shell
+snakemake -j 1 --snakefile ../../pipelines/preprocess_with_qc.snakefile --configfile ../../pipelines/config/deeprvat_preprocess_config.yaml
+```
+
+5. Enjoy the preprocessed data 🎉
+
+```shell
+ls -l workdir/preprocesed
+total 48
+-rw-r--r-- 1 user staff 6404 Aug 2 14:06 genotypes.h5
+-rw-r--r-- 1 user staff 6354 Aug 2 14:06 genotypes_chr21.h5
+-rw-r--r-- 1 user staff 6354 Aug 2 14:06 genotypes_chr22.h5
+```
+
+
+### Run the preprocess pipeline with example data and no qc
+
+![DeepRVAT preprocessing pipeline](_static/preprocess_rulegraph_no_qc.svg)
+
+*The vcf files in the example data folder were generated using [fake-vcf](https://github.com/endast/fake-vcf) (with some
+manual editing) and hence do not contain real data.*
+
+1. cd into the preprocessing example dir
+
+```shell
+cd
+cd example/preprocess
+```
+
+2. Download the fasta file
+
+```shell
+wget https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_44/GRCh38.primary_assembly.genome.fa.gz -P workdir/reference
+```
+
+3. Unpack the fasta file
+
+```shell
+gzip -d workdir/reference/GRCh38.primary_assembly.genome.fa.gz
+```
+
+4. Run with the example config
+
+```shell
+snakemake -j 1 --snakefile ../../pipelines/preprocess_no_qc.snakefile --configfile ../../pipelines/config/deeprvat_preprocess_config.yaml
+```
+
+5.
Enjoy the preprocessed data 🎉 + +```shell +ls -l workdir/preprocesed +total 48 +-rw-r--r-- 1 user staff 6404 Aug 2 14:06 genotypes.h5 +-rw-r--r-- 1 user staff 6354 Aug 2 14:06 genotypes_chr21.h5 +-rw-r--r-- 1 user staff 6354 Aug 2 14:06 genotypes_chr22.h5 +``` + +### Run on your own data with qc + +After configuration and activating the environment run the pipeline using snakemake: + +```shell +snakemake -j --configfile config/deeprvat_preprocess_config.yaml -s preprocess_with_qc.snakefile +``` + + +### Run on your own data without qc + +After configuration and activating the environment run the pipeline using snakemake: + +```shell +snakemake -j --configfile config/deeprvat_preprocess_config.yaml -s preprocess_no_qc.snakefile +``` \ No newline at end of file diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..97d35c58 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,6 @@ +sphinx==7.2.6 +myst-parser==2.0.0 +sphinx-autodoc2==0.4.2 +astroid==2.15.8 +sphinx-copybutton==0.5.2 +sphinx-rtd-theme==1.3.0 diff --git a/deeprvat/seed_gene_discovery/README.md b/docs/seed_gene_discovery.md similarity index 90% rename from deeprvat/seed_gene_discovery/README.md rename to docs/seed_gene_discovery.md index 1ae78f4a..c0fefda9 100644 --- a/deeprvat/seed_gene_discovery/README.md +++ b/docs/seed_gene_discovery.md @@ -6,7 +6,7 @@ To run the pipeline, an experiment directory with the `config.yaml` has to be cr ## Input data -The experiment directory in addition requires to have the same input data as specified for [DeepRVAT](https://github.com/PMBio/deeprvat/tree/main/README.md), including +The experiment directory in addition requires to have the same input data as specified for [DeepRVAT](usage.md), including - `annotations.parquet` - `protein_coding_genes.parquet` - `genotypes.h5` @@ -23,7 +23,7 @@ The `annotations.parquet` data frame should have the following columns: ### Run the seed gene discovery pipeline with example data -Create the conda environment and activate it, (instructions can be found in the [DeepRVAT README](https://github.com/PMBio/deeprvat/tree/main/README.md) ) +Create the conda environment and activate it, (instructions can be found here [DeepRVAT instructions](usage.md) ) ``` diff --git a/docs/usage.md b/docs/usage.md new file mode 100644 index 00000000..5d7c9170 --- /dev/null +++ b/docs/usage.md @@ -0,0 +1,85 @@ +# Using DeepRVAT + +## Installation + +1. Clone this repository: +```shell +git clone git@github.com:PMBio/deeprvat.git +``` +1. Change directory to the repository: `cd deeprvat` +1. Install the conda environment. We recommend using [mamba](https://mamba.readthedocs.io/en/latest/index.html), though you may also replace `mamba` with `conda` + + *note: [the current deeprvat env does not support cuda when installed with conda](https://github.com/PMBio/deeprvat/issues/16), install using mamba for cuda support.* +```shell +mamba env create -n deeprvat -f deeprvat_env.yaml +``` +1. Activate the environment: `mamba activate deeprvat` +1. Install the `deeprvat` package: `pip install -e .` + +If you don't want to install the gpu related requirements use the `deeprvat_env_no_gpu.yml` environment instead. +```shell +mamba env create -n deeprvat -f deeprvat_env_no_gpu.yaml +``` + + +## Basic usage + +### Customize pipelines + +Before running any of the snakefiles, you may want to adjust the number of threads used by different steps in the pipeline. To do this, modify the `threads:` property of a given rule. 
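+
+As an alternative to editing the snakefile, recent snakemake releases can also override a rule's thread count from the command line; a rough sketch, where `example_rule` is a placeholder (check the snakefile for the actual rule names):
+
+```shell
+# cap one rule at 8 threads and the whole run at 8 cores, without touching the snakefile
+snakemake -j 8 --set-threads example_rule=8 \
+    --snakefile [path_to_deeprvat]/pipelines/training_association_testing.snakefile
+```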
+ +If you are running on a computing cluster, you will need a [profile](https://github.com/snakemake-profiles) and may need to add `resources:` directives to the snakefiles. + + +### Run the preprocessing pipeline on VCF files + +Instructions [here](preprocessing.md) + + +### Annotate variants + +Instructions [here](annotations.md) + + + +### Try the full training and association testing pipeline on some example data + +```shell +mkdir example +cd example +ln -s [path_to_deeprvat]/example/* . +snakemake -j 1 --snakefile [path_to_deeprvat]/pipelines/training_association_testing.snakefile +``` + +Replace `[path_to_deeprvat]` with the path to your clone of the repository. + +Note that the example data is randomly generated, and so is only suited for testing whether the `deeprvat` package has been correctly installed. + + +### Run the training pipeline on some example data + +```shell +mkdir example +cd example +ln -s [path_to_deeprvat]/example/* . +snakemake -j 1 --snakefile [path_to_deeprvat]/pipelines/run_training.snakefile +``` + +Replace `[path_to_deeprvat]` with the path to your clone of the repository. + +Note that the example data is randomly generated, and so is only suited for testing whether the `deeprvat` package has been correctly installed. + + +### Run the association testing pipeline with pretrained models + +```shell +mkdir example +cd example +ln -s [path_to_deeprvat]/example/* . +ln -s [path_to_deeprvat]/pretrained_models +snakemake -j 1 --snakefile [path_to_deeprvat]/pipelines/association_testing_pretrained.snakefile +``` + +Replace `[path_to_deeprvat]` with the path to your clone of the repository. + +Again, note that the example data is randomly generated, and so is only suited for testing whether the `deeprvat` package has been correctly installed. 
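+
+When moving from the example data to a real dataset on a compute cluster, the same pipelines are typically dispatched through the snakemake profile mentioned under "Customize pipelines". A minimal sketch, where `my_profile` is a placeholder for a profile you have set up and the job count is arbitrary:
+
+```shell
+# submit pipeline jobs through a snakemake profile instead of running them locally
+snakemake -j 48 --profile my_profile \
+    --snakefile [path_to_deeprvat]/pipelines/association_testing_pretrained.snakefile
+```
+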
diff --git a/example/preprocess/data/metadata/pvcf_blocks.txt b/example/annotations/input_dir/vcf/metadata/pvcf_blocks.txt similarity index 100% rename from example/preprocess/data/metadata/pvcf_blocks.txt rename to example/annotations/input_dir/vcf/metadata/pvcf_blocks.txt diff --git a/example/annotations/input_dir/vcf/test_vcf_data_c21_b1.vcf.gz b/example/annotations/input_dir/vcf/test_vcf_data_c21_b1.vcf.gz new file mode 100644 index 00000000..df2edae1 Binary files /dev/null and b/example/annotations/input_dir/vcf/test_vcf_data_c21_b1.vcf.gz differ diff --git a/example/annotations/input_dir/vcf/test_vcf_data_c22_b1.vcf.gz b/example/annotations/input_dir/vcf/test_vcf_data_c22_b1.vcf.gz new file mode 100644 index 00000000..6228dc90 Binary files /dev/null and b/example/annotations/input_dir/vcf/test_vcf_data_c22_b1.vcf.gz differ diff --git a/example/annotations/output_dir/annotations/.gitkeep b/example/annotations/output_dir/annotations/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/example/annotations/preprocessing_workdir/norm/variants/variants.tsv.gz b/example/annotations/preprocessing_workdir/norm/variants/variants.tsv.gz new file mode 100644 index 00000000..e69de29b diff --git a/example/annotations/reference/hg38.fa b/example/annotations/reference/hg38.fa new file mode 100644 index 00000000..e69de29b diff --git a/example/baseline_results/Apolipoprotein_A/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Apolipoprotein_A/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Apolipoprotein_A/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Apolipoprotein_A/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Apolipoprotein_A/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Apolipoprotein_A/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Apolipoprotein_A/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Apolipoprotein_A/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Apolipoprotein_A/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Apolipoprotein_A/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Apolipoprotein_A/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Apolipoprotein_A/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Apolipoprotein_A/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Apolipoprotein_A/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Apolipoprotein_A/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Apolipoprotein_A/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Apolipoprotein_B/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Apolipoprotein_B/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Apolipoprotein_B/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Apolipoprotein_B/missense/burden/eval/burden_associations.parquet diff --git 
a/example/baseline_results/Apolipoprotein_B/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Apolipoprotein_B/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Apolipoprotein_B/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Apolipoprotein_B/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Apolipoprotein_B/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Apolipoprotein_B/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Apolipoprotein_B/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Apolipoprotein_B/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Apolipoprotein_B/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Apolipoprotein_B/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Apolipoprotein_B/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Apolipoprotein_B/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Calcium/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Calcium/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Calcium/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Calcium/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Calcium/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Calcium/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Calcium/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Calcium/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Calcium/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Calcium/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Calcium/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Calcium/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Calcium/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Calcium/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Calcium/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Calcium/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Cholesterol/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Cholesterol/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Cholesterol/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Cholesterol/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Cholesterol/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Cholesterol/missense/skat/eval/burden_associations.parquet similarity index 100% rename from 
example/baseline_results/Cholesterol/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Cholesterol/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Cholesterol/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Cholesterol/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Cholesterol/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Cholesterol/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Cholesterol/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Cholesterol/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Cholesterol/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Cholesterol/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/HDL_cholesterol/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/HDL_cholesterol/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/HDL_cholesterol/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/HDL_cholesterol/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/HDL_cholesterol/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/HDL_cholesterol/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/HDL_cholesterol/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/HDL_cholesterol/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/HDL_cholesterol/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/HDL_cholesterol/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/HDL_cholesterol/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/HDL_cholesterol/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/HDL_cholesterol/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/HDL_cholesterol/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/HDL_cholesterol/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/HDL_cholesterol/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/IGF_1/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/IGF_1/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/IGF_1/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/IGF_1/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/IGF_1/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/IGF_1/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/IGF_1/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/IGF_1/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/IGF_1/plof/burden/eval/burden_associations_testing.parquet 
b/example/baseline_results/IGF_1/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/IGF_1/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/IGF_1/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/IGF_1/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/IGF_1/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/IGF_1/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/IGF_1/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/LDL_direct/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/LDL_direct/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/LDL_direct/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/LDL_direct/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/LDL_direct/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/LDL_direct/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/LDL_direct/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/LDL_direct/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/LDL_direct/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/LDL_direct/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/LDL_direct/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/LDL_direct/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/LDL_direct/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/LDL_direct/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/LDL_direct/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/LDL_direct/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Lymphocyte_percentage/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Lymphocyte_percentage/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Lymphocyte_percentage/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Lymphocyte_percentage/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Lymphocyte_percentage/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Lymphocyte_percentage/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Lymphocyte_percentage/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Lymphocyte_percentage/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Lymphocyte_percentage/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Lymphocyte_percentage/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Lymphocyte_percentage/plof/burden/eval/burden_associations_testing.parquet rename to 
example/baseline_results/Lymphocyte_percentage/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Lymphocyte_percentage/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Lymphocyte_percentage/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Lymphocyte_percentage/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Lymphocyte_percentage/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Mean_corpuscular_volume/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Mean_corpuscular_volume/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Mean_corpuscular_volume/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Mean_corpuscular_volume/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Mean_corpuscular_volume/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Mean_corpuscular_volume/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Mean_corpuscular_volume/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Mean_corpuscular_volume/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Mean_corpuscular_volume/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Mean_corpuscular_volume/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Mean_corpuscular_volume/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Mean_corpuscular_volume/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Mean_corpuscular_volume/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Mean_corpuscular_volume/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Mean_corpuscular_volume/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Mean_corpuscular_volume/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Mean_platelet_thrombocyte_volume/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Mean_platelet_thrombocyte_volume/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Mean_platelet_thrombocyte_volume/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Mean_platelet_thrombocyte_volume/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Mean_platelet_thrombocyte_volume/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Mean_platelet_thrombocyte_volume/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Mean_platelet_thrombocyte_volume/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Mean_platelet_thrombocyte_volume/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Mean_platelet_thrombocyte_volume/plof/burden/eval/burden_associations_testing.parquet 
b/example/baseline_results/Mean_platelet_thrombocyte_volume/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Mean_platelet_thrombocyte_volume/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Mean_platelet_thrombocyte_volume/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Mean_platelet_thrombocyte_volume/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Mean_platelet_thrombocyte_volume/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Mean_platelet_thrombocyte_volume/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Mean_platelet_thrombocyte_volume/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Mean_reticulocyte_volume/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Mean_reticulocyte_volume/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Mean_reticulocyte_volume/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Mean_reticulocyte_volume/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Mean_reticulocyte_volume/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Mean_reticulocyte_volume/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Mean_reticulocyte_volume/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Mean_reticulocyte_volume/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Mean_reticulocyte_volume/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Mean_reticulocyte_volume/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Mean_reticulocyte_volume/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Mean_reticulocyte_volume/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Mean_reticulocyte_volume/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Mean_reticulocyte_volume/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Mean_reticulocyte_volume/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Mean_reticulocyte_volume/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Neutrophill_count/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Neutrophill_count/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Neutrophill_count/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Neutrophill_count/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Neutrophill_count/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Neutrophill_count/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Neutrophill_count/missense/skat/eval/burden_associations_testing.parquet rename to 
example/baseline_results/Neutrophill_count/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Neutrophill_count/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Neutrophill_count/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Neutrophill_count/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Neutrophill_count/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Neutrophill_count/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Neutrophill_count/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Neutrophill_count/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Neutrophill_count/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Platelet_count/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Platelet_count/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Platelet_count/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Platelet_count/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Platelet_count/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Platelet_count/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Platelet_count/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Platelet_count/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Platelet_count/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Platelet_count/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Platelet_count/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Platelet_count/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Platelet_count/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Platelet_count/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Platelet_count/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Platelet_count/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Platelet_crit/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Platelet_crit/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Platelet_crit/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Platelet_crit/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Platelet_crit/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Platelet_crit/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Platelet_crit/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Platelet_crit/missense/skat/eval/burden_associations.parquet diff --git 
a/example/baseline_results/Platelet_crit/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Platelet_crit/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Platelet_crit/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Platelet_crit/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Platelet_crit/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Platelet_crit/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Platelet_crit/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Platelet_crit/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Platelet_distribution_width/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Platelet_distribution_width/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Platelet_distribution_width/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Platelet_distribution_width/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Platelet_distribution_width/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Platelet_distribution_width/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Platelet_distribution_width/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Platelet_distribution_width/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Platelet_distribution_width/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Platelet_distribution_width/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Platelet_distribution_width/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Platelet_distribution_width/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Platelet_distribution_width/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Platelet_distribution_width/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Platelet_distribution_width/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Platelet_distribution_width/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Red_blood_cell_erythrocyte_count/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Red_blood_cell_erythrocyte_count/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Red_blood_cell_erythrocyte_count/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Red_blood_cell_erythrocyte_count/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Red_blood_cell_erythrocyte_count/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Red_blood_cell_erythrocyte_count/missense/skat/eval/burden_associations.parquet similarity index 100% rename from 
example/baseline_results/Red_blood_cell_erythrocyte_count/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Red_blood_cell_erythrocyte_count/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Red_blood_cell_erythrocyte_count/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Red_blood_cell_erythrocyte_count/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Red_blood_cell_erythrocyte_count/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Red_blood_cell_erythrocyte_count/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Red_blood_cell_erythrocyte_count/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Red_blood_cell_erythrocyte_count/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Red_blood_cell_erythrocyte_count/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Red_blood_cell_erythrocyte_count/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/SHBG/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/SHBG/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/SHBG/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/SHBG/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/SHBG/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/SHBG/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/SHBG/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/SHBG/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/SHBG/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/SHBG/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/SHBG/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/SHBG/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/SHBG/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/SHBG/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/SHBG/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/SHBG/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Standing_height/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Standing_height/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Standing_height/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Standing_height/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Standing_height/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Standing_height/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Standing_height/missense/skat/eval/burden_associations_testing.parquet rename to 
example/baseline_results/Standing_height/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Standing_height/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Standing_height/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Standing_height/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Standing_height/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Standing_height/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Standing_height/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Standing_height/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Standing_height/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Total_bilirubin/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Total_bilirubin/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Total_bilirubin/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Total_bilirubin/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Total_bilirubin/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Total_bilirubin/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Total_bilirubin/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Total_bilirubin/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Total_bilirubin/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Total_bilirubin/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Total_bilirubin/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Total_bilirubin/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Total_bilirubin/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Total_bilirubin/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Total_bilirubin/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Total_bilirubin/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Triglycerides/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Triglycerides/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Triglycerides/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Triglycerides/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Triglycerides/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Triglycerides/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Triglycerides/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Triglycerides/missense/skat/eval/burden_associations.parquet diff --git 
a/example/baseline_results/Triglycerides/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Triglycerides/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Triglycerides/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Triglycerides/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Triglycerides/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Triglycerides/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Triglycerides/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Triglycerides/plof/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Urate/missense/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Urate/missense/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Urate/missense/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Urate/missense/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Urate/missense/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Urate/missense/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Urate/missense/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Urate/missense/skat/eval/burden_associations.parquet diff --git a/example/baseline_results/Urate/plof/burden/eval/burden_associations_testing.parquet b/example/baseline_results/Urate/plof/burden/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Urate/plof/burden/eval/burden_associations_testing.parquet rename to example/baseline_results/Urate/plof/burden/eval/burden_associations.parquet diff --git a/example/baseline_results/Urate/plof/skat/eval/burden_associations_testing.parquet b/example/baseline_results/Urate/plof/skat/eval/burden_associations.parquet similarity index 100% rename from example/baseline_results/Urate/plof/skat/eval/burden_associations_testing.parquet rename to example/baseline_results/Urate/plof/skat/eval/burden_associations.parquet diff --git a/example/preprocess/data/vcf/test_vcf_data_c21_b1.vcf.gz b/example/preprocess/data/vcf/test_vcf_data_c21_b1.vcf.gz index f7c01f17..df2edae1 100644 Binary files a/example/preprocess/data/vcf/test_vcf_data_c21_b1.vcf.gz and b/example/preprocess/data/vcf/test_vcf_data_c21_b1.vcf.gz differ diff --git a/example/preprocess/data/vcf/test_vcf_data_c22_b1.vcf.gz b/example/preprocess/data/vcf/test_vcf_data_c22_b1.vcf.gz index 508766dc..6228dc90 100644 Binary files a/example/preprocess/data/vcf/test_vcf_data_c22_b1.vcf.gz and b/example/preprocess/data/vcf/test_vcf_data_c22_b1.vcf.gz differ diff --git a/example/preprocess/vcf_files_list.txt b/example/preprocess/vcf_files_list.txt new file mode 100644 index 00000000..6b196815 --- /dev/null +++ b/example/preprocess/vcf_files_list.txt @@ -0,0 +1,2 @@ +data/vcf/test_vcf_data_c21_b1.vcf.gz +data/vcf/test_vcf_data_c22_b1.vcf.gz diff --git a/lsf/config.yaml b/lsf/config.yaml new file mode 100644 index 00000000..d7bf68fd --- /dev/null +++ b/lsf/config.yaml @@ -0,0 +1,472 @@ +phenotypes: + Apolipoprotein_A: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Apolipoprotein_A + Apolipoprotein_B: + correction_method: FDR + n_training_genes: 40 
+ baseline_phenotype: Apolipoprotein_B + Calcium: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Calcium + Cholesterol: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Cholesterol + Red_blood_cell_erythrocyte_count: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Red_blood_cell_erythrocyte_count + HDL_cholesterol: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: HDL_cholesterol + IGF_1: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: IGF_1 + LDL_direct: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: LDL_direct + Lymphocyte_percentage: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Lymphocyte_percentage + Mean_platelet_thrombocyte_volume: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Mean_platelet_thrombocyte_volume + Mean_corpuscular_volume: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Mean_corpuscular_volume + Mean_reticulocyte_volume: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Mean_reticulocyte_volume + Neutrophill_count: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Neutrophill_count + Platelet_count: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Platelet_count + Platelet_crit: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Platelet_crit + Platelet_distribution_width: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Platelet_distribution_width + SHBG: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: SHBG + Standing_height: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Standing_height + Total_bilirubin: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Total_bilirubin + Triglycerides: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Triglycerides + Urate: + correction_method: FDR + n_training_genes: 40 + baseline_phenotype: Urate + Body_mass_index_BMI: + correction_method: FDR + baseline_phenotype: Body_mass_index_BMI + Glucose: + correction_method: FDR + baseline_phenotype: Glucose + Vitamin_D: + correction_method: FDR + baseline_phenotype: Vitamin_D + Albumin: + correction_method: FDR + baseline_phenotype: Albumin + Total_protein: + correction_method: FDR + baseline_phenotype: Total_protein + Cystatin_C: + correction_method: FDR + baseline_phenotype: Cystatin_C + Gamma_glutamyltransferase: + correction_method: FDR + baseline_phenotype: Gamma_glutamyltransferase + Alkaline_phosphatase: + correction_method: FDR + baseline_phenotype: Alkaline_phosphatase + Creatinine: + correction_method: FDR + baseline_phenotype: Creatinine + Whole_body_fat_free_mass: + correction_method: FDR + baseline_phenotype: Whole_body_fat_free_mass + Forced_expiratory_volume_in_1_second_FEV1: + correction_method: FDR + baseline_phenotype: Forced_expiratory_volume_in_1_second_FEV1 + QTC_interval: + correction_method: FDR + baseline_phenotype: QTC_interval + Glycated_haemoglobin_HbA1c: + correction_method: FDR + baseline_phenotype: Glycated_haemoglobin_HbA1c + WHR: + correction_method: FDR + baseline_phenotype: WHR + WHR_Body_mass_index_BMI_corrected: + correction_method: FDR + baseline_phenotype: WHR_Body_mass_index_BMI_corrected + +baseline_results: + - + base: baseline_results + type: plof/burden + - + base: baseline_results + type: missense/burden + - + base: baseline_results + type: plof/skat + - + base: 
baseline_results + type: missense/skat + +alpha: 0.05 + +n_burden_chunks: 4 +n_regression_chunks: 2 + +n_repeats: 6 + +do_scoretest: True + +training: + min_variant_count: 1 + n_bags: 1 + drop_n_bags: 0 + train_proportion: 0.8 + sample_with_replacement: False + n_parallel_jobs: 6 + dataloader_config: + batch_size: 1024 + num_workers: 0 + temp_dir: $TMPDIR/deeprvat_train + cache_tensors: True + chunksize: 100 + phenotypes: + - Apolipoprotein_A + - Apolipoprotein_B + - Calcium + - Cholesterol + - Red_blood_cell_erythrocyte_count + - HDL_cholesterol + - IGF_1 + - LDL_direct + - Lymphocyte_percentage + - Mean_platelet_thrombocyte_volume + - Mean_corpuscular_volume + - Mean_reticulocyte_volume + - Neutrophill_count + - Platelet_count + - Platelet_crit + - Platelet_distribution_width + - SHBG + - Standing_height + - Total_bilirubin + - Triglycerides + - Urate + + +pl_trainer: + gpus: 1 + precision: 16 + min_epochs: 50 + max_epochs: 1000 + log_every_n_steps: 1 + check_val_every_n_epoch: 1 + +early_stopping: + mode: min + patience: 3 + min_delta: 0.00001 + verbose: True + +hyperparameter_optimization: + direction: maximize + n_trials: 1 + sampler: + type: TPESampler + config: {} + +model: + type: DeepSet + model_collection: agg_models + checkpoint: combined_agg.pt + config: + phi_layers: 2 + phi_hidden_dim: 20 + rho_layers: 3 + rho_hidden_dim: 10 + activation: LeakyReLU + pool: max + use_sigmoid: True + metrics: + objective: MSE + objective_mode: min + loss: MSE + all: + MSE: {} + PearsonCorrTorch: {} + MAE: {} + RSquared: {} + optimizer: + type: AdamW + config: {} + +training_data: + gt_file: genotypes.h5 + variant_file: variants.parquet + dataset_config: + min_common_af: + combined_UKB_NFE_AF: 0.01 + phenotype_file: phenotypes.parquet + y_transformation: quantile_transform + x_phenotypes: + - age + - genetic_sex + - genetic_PC_1 + - genetic_PC_2 + - genetic_PC_3 + - genetic_PC_4 + - genetic_PC_5 + - genetic_PC_6 + - genetic_PC_7 + - genetic_PC_8 + - genetic_PC_9 + - genetic_PC_10 + - genetic_PC_11 + - genetic_PC_12 + - genetic_PC_13 + - genetic_PC_14 + - genetic_PC_15 + - genetic_PC_16 + - genetic_PC_17 + - genetic_PC_18 + - genetic_PC_19 + - genetic_PC_20 + annotation_file: annotations.parquet + annotations: + - combined_UKB_NFE_AF + - combined_UKB_NFE_AF_MB + - CADD_PHRED + - CADD_raw + - sift_score + - polyphen_score + - Consequence_splice_acceptor_variant + - Consequence_splice_donor_variant + - Consequence_stop_gained + - Consequence_frameshift_variant + - Consequence_stop_lost + - Consequence_start_lost + - Consequence_inframe_insertion + - Consequence_inframe_deletion + - Consequence_missense_variant + - Consequence_protein_altering_variant + - Consequence_splice_region_variant + - condel_score + - DeepSEA_PC_1 + - DeepSEA_PC_2 + - DeepSEA_PC_3 + - DeepSEA_PC_4 + - DeepSEA_PC_5 + - DeepSEA_PC_6 + - PrimateAI_score + - AbSplice_DNA + - DeepRipe_plus_QKI_lip_hg2 + - DeepRipe_plus_QKI_clip_k5 + - DeepRipe_plus_KHDRBS1_clip_k5 + - DeepRipe_plus_ELAVL1_parclip + - DeepRipe_plus_TARDBP_parclip + - DeepRipe_plus_HNRNPD_parclip + - DeepRipe_plus_MBNL1_parclip + - DeepRipe_plus_QKI_parclip + - SpliceAI_delta_score + use_common_variants: False + use_rare_variants: True + rare_embedding: + type: PaddedAnnotations + config: + annotations: + - combined_UKB_NFE_AF_MB + - CADD_raw + - sift_score + - polyphen_score + - Consequence_splice_acceptor_variant + - Consequence_splice_donor_variant + - Consequence_stop_gained + - Consequence_frameshift_variant + - Consequence_stop_lost + - Consequence_start_lost 
+ - Consequence_inframe_insertion + - Consequence_inframe_deletion + - Consequence_missense_variant + - Consequence_protein_altering_variant + - Consequence_splice_region_variant + - condel_score + - DeepSEA_PC_1 + - DeepSEA_PC_2 + - DeepSEA_PC_3 + - DeepSEA_PC_4 + - DeepSEA_PC_5 + - DeepSEA_PC_6 + - PrimateAI_score + - AbSplice_DNA + - DeepRipe_plus_QKI_lip_hg2 + - DeepRipe_plus_QKI_clip_k5 + - DeepRipe_plus_KHDRBS1_clip_k5 + - DeepRipe_plus_ELAVL1_parclip + - DeepRipe_plus_TARDBP_parclip + - DeepRipe_plus_HNRNPD_parclip + - DeepRipe_plus_MBNL1_parclip + - DeepRipe_plus_QKI_parclip + - SpliceAI_delta_score + thresholds: + combined_UKB_NFE_AF: "combined_UKB_NFE_AF < 1e-2" + CADD_PHRED: "CADD_PHRED > 5" + verbose: True + low_memory: True + verbose: True + dataloader_config: + batch_size: 64 + num_workers: 8 + +data: + gt_file: genotypes.h5 + variant_file: variants.parquet + dataset_config: + min_common_af: + combined_UKB_NFE_AF: 0.01 + phenotype_file: phenotypes.parquet + y_transformation: quantile_transform + x_phenotypes: + - age + - genetic_sex + - genetic_PC_1 + - genetic_PC_2 + - genetic_PC_3 + - genetic_PC_4 + - genetic_PC_5 + - genetic_PC_6 + - genetic_PC_7 + - genetic_PC_8 + - genetic_PC_9 + - genetic_PC_10 + - genetic_PC_11 + - genetic_PC_12 + - genetic_PC_13 + - genetic_PC_14 + - genetic_PC_15 + - genetic_PC_16 + - genetic_PC_17 + - genetic_PC_18 + - genetic_PC_19 + - genetic_PC_20 + annotation_file: annotations.parquet + annotations: + - combined_UKB_NFE_AF + - combined_UKB_NFE_AF_MB + - CADD_PHRED + - CADD_raw + - sift_score + - polyphen_score + - Consequence_splice_acceptor_variant + - Consequence_splice_donor_variant + - Consequence_stop_gained + - Consequence_frameshift_variant + - Consequence_stop_lost + - Consequence_start_lost + - Consequence_inframe_insertion + - Consequence_inframe_deletion + - Consequence_missense_variant + - Consequence_protein_altering_variant + - Consequence_splice_region_variant + - condel_score + - DeepSEA_PC_1 + - DeepSEA_PC_2 + - DeepSEA_PC_3 + - DeepSEA_PC_4 + - DeepSEA_PC_5 + - DeepSEA_PC_6 + - PrimateAI_score + - AbSplice_DNA + - DeepRipe_plus_QKI_lip_hg2 + - DeepRipe_plus_QKI_clip_k5 + - DeepRipe_plus_KHDRBS1_clip_k5 + - DeepRipe_plus_ELAVL1_parclip + - DeepRipe_plus_TARDBP_parclip + - DeepRipe_plus_HNRNPD_parclip + - DeepRipe_plus_MBNL1_parclip + - DeepRipe_plus_QKI_parclip + - SpliceAI_delta_score + gene_file: protein_coding_genes.parquet + use_common_variants: False + use_rare_variants: True + rare_embedding: + type: PaddedAnnotations + config: + annotations: + - combined_UKB_NFE_AF_MB + - CADD_raw + - sift_score + - polyphen_score + - Consequence_splice_acceptor_variant + - Consequence_splice_donor_variant + - Consequence_stop_gained + - Consequence_frameshift_variant + - Consequence_stop_lost + - Consequence_start_lost + - Consequence_inframe_insertion + - Consequence_inframe_deletion + - Consequence_missense_variant + - Consequence_protein_altering_variant + - Consequence_splice_region_variant + - condel_score + - DeepSEA_PC_1 + - DeepSEA_PC_2 + - DeepSEA_PC_3 + - DeepSEA_PC_4 + - DeepSEA_PC_5 + - DeepSEA_PC_6 + - PrimateAI_score + - AbSplice_DNA + - DeepRipe_plus_QKI_lip_hg2 + - DeepRipe_plus_QKI_clip_k5 + - DeepRipe_plus_KHDRBS1_clip_k5 + - DeepRipe_plus_ELAVL1_parclip + - DeepRipe_plus_TARDBP_parclip + - DeepRipe_plus_HNRNPD_parclip + - DeepRipe_plus_MBNL1_parclip + - DeepRipe_plus_QKI_parclip + - SpliceAI_delta_score + thresholds: + combined_UKB_NFE_AF: "combined_UKB_NFE_AF < 1e-3" + CADD_PHRED: "CADD_PHRED > 5" + gene_file: 
protein_coding_genes.parquet + verbose: True + low_memory: True + verbose: True + dataloader_config: + batch_size: 16 + num_workers: 10 diff --git a/lsf/lsf.yaml b/lsf/lsf.yaml new file mode 100644 index 00000000..cc972219 --- /dev/null +++ b/lsf/lsf.yaml @@ -0,0 +1,98 @@ +__default__: + - "-q medium" + - "-R \"select[(hname != 'odcf-cn11u15' && hname != 'odcf-cn31u13' && hname != 'odcf-cn31u21' && hname != 'odcf-cn23u23')]\"" + + +# For association testing pipelines + +config: + - "-q short" + +training_dataset: + - "-q long" + +delete_burden_cache: + - "-q short" + +choose_training_genes: + - "-q short" + +best_cv_run: + - "-q short" +link_avg_burdens: + - "-q short" +best_bagging_run: + - "-q short" + +train: + - "-q gpu" + - "-gpu num=1:gmem=10.7G" + - "-R \"select[(hname != 'e230-dgx2-1' && hname != 'e230-dgx2-2' && hname != 'e230-dgxa100-1' && hname != 'e230-dgxa100-2' && hname != 'e230-dgxa100-3' && hname != 'e230-dgxa100-4' && hname != 'e071-gpu06')]\"" + # - "-R tensorcore" + # - "-L /bin/bash" + +compute_burdens: + - "-q gpu" + - "-gpu num=1:j_exclusive=yes:mode=exclusive_process:gmem=15.7G" + - "-R \"select[(hname != 'e230-dgx2-1' && hname != 'e230-dgx2-2' && hname != 'e230-dgxa100-1' && hname != 'e230-dgxa100-2' && hname != 'e230-dgxa100-3' && hname != 'e230-dgxa100-4' && hname != 'e071-gpu06')]\"" + - "-W 180" + # - "-R tensorcore" + # - "-L /bin/bash" + +link_burdens: + - "-q medium" + +compute_plof_burdens: + - "-q medium" + +regress: + - "-q long" + +combine_regression_chunks: + - "-q short" + + +# For CV (phenotype prediction) pipeline + +deeprvat_config: + - "-q short" + +deeprvat_plof_config: + - "-q short" + +deeprvat_training_dataset: + - "-q long" + +deeprvat_delete_burden_cache: + - "-q short" + +deeprvat_best_cv_run: + - "-q short" + +deeprvat_train_cv: + - "-q gpu-lowprio" + - "-gpu num=1:j_exclusive=yes:mode=exclusive_process:gmem=10.7G" + - "-R \"select[(hname != 'e230-dgx2-1' && hname != 'e230-dgx2-2' && hname != 'e230-dgxa100-1' && hname != 'e230-dgxa100-2' && hname != 'e230-dgxa100-3' && hname != 'e071-gpu06')]\"" + # - "-R tensorcore" + # - "-L /bin/bash" + +deeprvat_train_bagging: + - "-q gpu-lowprio" + - "-gpu num=1:j_exclusive=yes:mode=exclusive_process:gmem=10.7G" + - "-R \"select[(hname != 'e230-dgx2-1' && hname != 'e230-dgx2-2' && hname != 'e230-dgxa100-1' && hname != 'e230-dgxa100-2' && hname != 'e230-dgxa100-3' && hname != 'e230-dgxa100-4' && hname != 'e071-gpu06')]\"" + # - "-R tensorcore" + # - "-L /bin/bash" + +deeprvat_compute_burdens: + - "-q gpu-lowprio" + - "-gpu num=1:j_exclusive=yes:mode=exclusive_process:gmem=10.7G" + - "-R \"select[(hname != 'e230-dgx2-1' && hname != 'e230-dgx2-2' && hname != 'e230-dgxa100-1' && hname != 'e230-dgxa100-2' && hname != 'e230-dgxa100-3' && hname != 'e230-dgxa100-4' && hname != 'e071-gpu06')]\"" + - "-W 180" + # - "-R tensorcore" + # - "-L /bin/bash" + +deeprvat_compute_plof_burdens: + - "-q medium" + +deeprvat_regress: + - "-q long" diff --git a/lsf/training_association_testing.snakefile b/lsf/training_association_testing.snakefile new file mode 100644 index 00000000..3d1ecdc6 --- /dev/null +++ b/lsf/training_association_testing.snakefile @@ -0,0 +1,385 @@ +from pathlib import Path + +configfile: 'config.yaml' + +debug_flag = config.get('debug', False) +phenotypes = config['phenotypes'] +phenotypes = list(phenotypes.keys()) if type(phenotypes) == dict else phenotypes +training_phenotypes = config["training"].get("phenotypes", phenotypes) + +n_burden_chunks = config.get('n_burden_chunks', 1) if not debug_flag 
else 2 +n_regression_chunks = config.get('n_regression_chunks', 40) if not debug_flag else 2 +n_trials = config['hyperparameter_optimization']['n_trials'] +n_bags = config['training']['n_bags'] if not debug_flag else 3 +n_repeats = config['n_repeats'] +debug = '--debug ' if debug_flag else '' +do_scoretest = '--do-scoretest ' if config.get('do_scoretest', False) else '' +tensor_compression_level = config['training'].get('tensor_compression_level', 1) +n_parallel_training_jobs = config["training"].get("n_parallel_jobs", 1) + +wildcard_constraints: + repeat="\d+", + trial="\d+", + +rule all: + input: + expand("{phenotype}/deeprvat/eval/significant.parquet", + phenotype=phenotypes), + expand("{phenotype}/deeprvat/eval/all_results.parquet", + phenotype=phenotypes) + +rule evaluate: + input: + associations = expand('{{phenotype}}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', + repeat=range(n_repeats)), + config = '{phenotype}/deeprvat/hpopt_config.yaml', + output: + "{phenotype}/deeprvat/eval/significant.parquet", + "{phenotype}/deeprvat/eval/all_results.parquet" + threads: 1 + resources: + mem_mb = 16000, + load = 16000 + shell: + 'deeprvat_evaluate ' + + debug + + '--use-seed-genes ' + '--n-repeats {n_repeats} ' + '--correction-method FDR ' + '{input.associations} ' + '{input.config} ' + '{wildcards.phenotype}/deeprvat/eval' + +rule all_regression: + input: + expand('{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', + phenotype=phenotypes, type=['deeprvat'], repeat=range(n_repeats)), + +rule combine_regression_chunks: + input: + expand('{{phenotype}}/deeprvat/repeat_{{repeat}}/results/burden_associations_{chunk}.parquet', chunk=range(n_regression_chunks)), + output: + '{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', + threads: 1 + resources: + mem_mb = 2048, + load = 2000 + shell: + 'deeprvat_associate combine-regression-results ' + '--model-name repeat_{wildcards.repeat} ' + '{input} ' + '{output}' + +rule regress: + input: + config = "{phenotype}/deeprvat/hpopt_config.yaml", + chunks = lambda wildcards: ( + [] if wildcards.phenotype == phenotypes[0] + else expand('{{phenotype}}/deeprvat/burdens/chunk{chunk}.linked', + chunk=range(n_burden_chunks)) + ), + phenotype_0_chunks = expand( + phenotypes[0] + '/deeprvat/burdens/chunk{chunk}.finished', + chunk=range(n_burden_chunks) + ), + output: + temp('{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations_{chunk}.parquet'), + threads: 2 + resources: + mem_mb = lambda wildcards, attempt: 28676 + (attempt - 1) * 4098, + # mem_mb = 16000, + load = lambda wildcards, attempt: 28000 + (attempt - 1) * 4000 + shell: + 'deeprvat_associate regress ' + + debug + + '--chunk {wildcards.chunk} ' + '--n-chunks ' + str(n_regression_chunks) + ' ' + '--use-bias ' + '--repeat {wildcards.repeat} ' + + do_scoretest + + '{input.config} ' + '{wildcards.phenotype}/deeprvat/burdens ' #TODO make this w/o repeats + '{wildcards.phenotype}/deeprvat/repeat_{wildcards.repeat}/results' + +rule all_burdens: + input: + [ + (f'{p}/deeprvat/burdens/chunk{c}.' 
+ + ("finished" if p == phenotypes[0] else "linked")) + for p in phenotypes + for c in range(n_burden_chunks) + ] + +rule link_burdens: + priority: 1 + input: + checkpoints = lambda wildcards: [ + f'models/repeat_{repeat}/best/bag_{bag}.ckpt' + for repeat in range(n_repeats) for bag in range(n_bags) + ], + dataset = '{phenotype}/deeprvat/association_dataset.pkl', + data_config = '{phenotype}/deeprvat/hpopt_config.yaml', + model_config = 'models/config.yaml', + output: + '{phenotype}/deeprvat/burdens/chunk{chunk}.linked' + threads: 8 + resources: + mem_mb = lambda wildcards, attempt: 20480 + (attempt - 1) * 4098, + # mem_mb = 16000, + load = lambda wildcards, attempt: 16000 + (attempt - 1) * 4000 + shell: + ' && '.join([ + ('deeprvat_associate compute-burdens ' + + debug + + ' --n-chunks '+ str(n_burden_chunks) + ' ' + f'--link-burdens ../../../{phenotypes[0]}/deeprvat/burdens/burdens.zarr ' + '--chunk {wildcards.chunk} ' + '--dataset-file {input.dataset} ' + '{input.data_config} ' + '{input.model_config} ' + '{input.checkpoints} ' + '{wildcards.phenotype}/deeprvat/burdens'), + 'touch {output}' + ]) + +rule compute_burdens: + priority: 10 + input: + reversed = "models/reverse_finished.tmp", + checkpoints = lambda wildcards: [ + f'models/repeat_{repeat}/best/bag_{bag}.ckpt' + for repeat in range(n_repeats) for bag in range(n_bags) + ], + dataset = '{phenotype}/deeprvat/association_dataset.pkl', + data_config = '{phenotype}/deeprvat/hpopt_config.yaml', + model_config = 'models/config.yaml', + output: + '{phenotype}/deeprvat/burdens/chunk{chunk}.finished' + threads: 8 + resources: + mem_mb = 2000000, # Using this value will tell our modified lsf.profile not to set a memory resource + load = 8000, + gpus = 1 + shell: + ' && '.join([ + ('deeprvat_associate compute-burdens ' + + debug + + ' --n-chunks '+ str(n_burden_chunks) + ' ' + '--chunk {wildcards.chunk} ' + '--dataset-file {input.dataset} ' + '{input.data_config} ' + '{input.model_config} ' + '{input.checkpoints} ' + '{wildcards.phenotype}/deeprvat/burdens'), + 'touch {output}' + ]) + +rule all_association_dataset: + input: + expand('{phenotype}/deeprvat/association_dataset.pkl', + phenotype=phenotypes) + +rule association_dataset: + input: + config = '{phenotype}/deeprvat/hpopt_config.yaml' + output: + '{phenotype}/deeprvat/association_dataset.pkl' + threads: 4 + resources: + mem_mb = lambda wildcards, attempt: 32000 * (attempt + 1), + load = 64000 + shell: + 'deeprvat_associate make-dataset ' + + debug + + '{input.config} ' + '{output}' + +rule reverse_models: + input: + checkpoints = expand('models/repeat_{repeat}/best/bag_{bag}.ckpt', + bag=range(n_bags), repeat=range(n_repeats)), + model_config = 'models/config.yaml', + data_config = Path(phenotypes[0]) / "deeprvat/hpopt_config.yaml", + output: + "models/reverse_finished.tmp" + threads: 4 + resources: + mem_mb = 20480, + load = 20480 + shell: + " && ".join([ + ("deeprvat_associate reverse-models " + "{input.model_config} " + "{input.data_config} " + "{input.checkpoints}"), + "touch {output}" + ]) + +rule all_training: + input: + expand('models/repeat_{repeat}/best/bag_{bag}.ckpt', + bag=range(n_bags), repeat=range(n_repeats)), + "models/config.yaml" + +rule link_config: + input: + 'models/repeat_0/config.yaml' + output: + "models/config.yaml" + threads: 1 + shell: + "ln -s repeat_0/config.yaml {output}" + +rule best_training_run: + input: + expand('models/repeat_{{repeat}}/trial{trial_number}/config.yaml', + trial_number=range(n_trials)), + output: + checkpoints = 
expand('models/repeat_{{repeat}}/best/bag_{bag}.ckpt', + bag=range(n_bags)), + config = 'models/repeat_{repeat}/config.yaml' + threads: 1 + resources: + mem_mb = 2048, + load = 2000 + shell: + ( + 'deeprvat_train best-training-run ' + + debug + + 'models/repeat_{wildcards.repeat} ' + 'models/repeat_{wildcards.repeat}/best ' + 'models/repeat_{wildcards.repeat}/hyperparameter_optimization.db ' + '{output.config}' + ) + +rule train: + input: + config = expand('{phenotype}/deeprvat/hpopt_config.yaml', + phenotype=training_phenotypes), + input_tensor = expand('{phenotype}/deeprvat/input_tensor.zarr', + phenotype=training_phenotypes), + covariates = expand('{phenotype}/deeprvat/covariates.zarr', + phenotype=training_phenotypes), + y = expand('{phenotype}/deeprvat/y.zarr', + phenotype=training_phenotypes), + output: + expand('models/repeat_{repeat}/trial{trial_number}/config.yaml', + repeat=range(n_repeats), trial_number=range(n_trials)), + expand('models/repeat_{repeat}/trial{trial_number}/finished.tmp', + repeat=range(n_repeats), trial_number=range(n_trials)) + params: + phenotypes = " ".join( + [f"--phenotype {p} " + f"{p}/deeprvat/input_tensor.zarr " + f"{p}/deeprvat/covariates.zarr " + f"{p}/deeprvat/y.zarr" + for p in training_phenotypes]) + resources: + mem_mb = 2000000, # Using this value will tell our modified lsf.profile not to set a memory resource + load = 8000, + gpus = 1 + shell: + f"parallel --jobs {n_parallel_training_jobs} --halt now,fail=1 --results train_repeat{{{{1}}}}_trial{{{{2}}}}/ " + 'deeprvat_train train ' + + debug + + '--trial-id {{2}} ' + "{params.phenotypes} " + 'config.yaml ' + 'models/repeat_{{1}}/trial{{2}} ' + "models/repeat_{{1}}/hyperparameter_optimization.db '&&' " + "touch models/repeat_{{1}}/trial{{2}}/finished.tmp " + "::: " + " ".join(map(str, range(n_repeats))) + " " + "::: " + " ".join(map(str, range(n_trials))) + +rule all_training_dataset: + input: + input_tensor = expand('{phenotype}/deeprvat/input_tensor.zarr', + phenotype=training_phenotypes, repeat=range(n_repeats)), + covariates = expand('{phenotype}/deeprvat/covariates.zarr', + phenotype=training_phenotypes, repeat=range(n_repeats)), + y = expand('{phenotype}/deeprvat/y.zarr', + phenotype=training_phenotypes, repeat=range(n_repeats)) + +rule training_dataset: + input: + config = '{phenotype}/deeprvat/hpopt_config.yaml', + training_dataset = '{phenotype}/deeprvat/training_dataset.pkl' + output: + input_tensor = directory('{phenotype}/deeprvat/input_tensor.zarr'), + covariates = directory('{phenotype}/deeprvat/covariates.zarr'), + y = directory('{phenotype}/deeprvat/y.zarr') + threads: 8 + resources: + mem_mb = lambda wildcards, attempt: 64000 * (attempt + 1), + load = 16000 + priority: 50 + shell: + ( + 'deeprvat_train make-dataset ' + + debug + + '--compression-level ' + str(tensor_compression_level) + ' ' + '--training-dataset-file {input.training_dataset} ' + '{input.config} ' + '{output.input_tensor} ' + '{output.covariates} ' + '{output.y}' + ) + +rule training_dataset_pickle: + input: + '{phenotype}/deeprvat/hpopt_config.yaml' + output: + '{phenotype}/deeprvat/training_dataset.pkl' + threads: 1 + resources: + mem_mb = 40000, # lambda wildcards, attempt: 38000 + 12000 * attempt + load = 16000 + shell: + ( + 'deeprvat_train make-dataset ' + '--pickle-only ' + '--training-dataset-file {output} ' + '{input} ' + 'dummy dummy dummy' + ) + +rule all_config: + input: + seed_genes = expand('{phenotype}/deeprvat/seed_genes.parquet', + phenotype=phenotypes), + config = 
expand('{phenotype}/deeprvat/hpopt_config.yaml', + phenotype=phenotypes), + baseline = expand('{phenotype}/deeprvat/baseline_results.parquet', + phenotype=phenotypes), + +rule config: + input: + config = 'config.yaml', + baseline = lambda wildcards: [ + str(Path(r['base']) / wildcards.phenotype / r['type'] / + 'eval/burden_associations.parquet') + for r in config['baseline_results'] + ] + output: + seed_genes = '{phenotype}/deeprvat/seed_genes.parquet', + config = '{phenotype}/deeprvat/hpopt_config.yaml', + baseline = '{phenotype}/deeprvat/baseline_results.parquet', + threads: 1 + resources: + mem_mb = 1024, + load = 1000 + params: + baseline_results = lambda wildcards, input: ''.join([ + f'--baseline-results {b} ' + for b in input.baseline + ]) + shell: + ( + 'deeprvat_config update-config ' + '--phenotype {wildcards.phenotype} ' + '{params.baseline_results}' + '--baseline-results-out {output.baseline} ' + '--seed-genes-out {output.seed_genes} ' + '{input.config} ' + '{output.config}' + ) diff --git a/pipelines/annotations.snakefile b/pipelines/annotations.snakefile index 8f261cc6..a8a19116 100644 --- a/pipelines/annotations.snakefile +++ b/pipelines/annotations.snakefile @@ -388,11 +388,77 @@ rule deepRiPe_eclip_k5: shell: f"mkdir -p {pybedtools_tmp_path/'k5'} && python {annotation_python_file} scorevariants-deepripe {{input.variants}} {anno_dir} {{input.fasta}} {pybedtools_tmp_path/'k5'} {saved_deepripe_models_path} {{threads}} 'eclip_k5'" + output: + anno_dir / "all_variants.deepSea.csv", + shell: + " ".join( + [ + "python", + f"{annotation_python_file}", + "concatenate-deepripe", + "--included-chromosomes", + ",".join(included_chromosomes), + "--sep '\t'", + f"{anno_dir}", + str( + source_variant_file_pattern + ".CLI.deepseapredict.diff.tsv" + ).format(chr="{{chr}}", block="{{block}}"), + str(metadata_dir / config["pvcf_blocks_file"]), + str( + anno_dir / "all_variants.deepSea.csv", + ), + ] + ) + + +rule deepSea: + input: + variants=anno_tmp_dir + / (source_variant_file_pattern + "_variants_header.vcf.gz"), + fasta=fasta_dir / fasta_file_name, + output: + anno_dir / (source_variant_file_pattern + ".CLI.deepseapredict.diff.tsv"), + conda: + "kipoi-veff2" + shell: + "kipoi_veff2_predict {input.variants} {input.fasta} {output} -l 1000 -m 'DeepSEA/predict' -s 'diff'" + + +rule deepRiPe_parclip: + input: + variants=anno_tmp_dir / (source_variant_file_pattern + "_variants.vcf"), + fasta=fasta_dir / fasta_file_name, + output: + anno_dir / (source_variant_file_pattern + "_variants.parclip_deepripe.csv.gz"), + shell: + f"mkdir -p {pybedtools_tmp_path / 'parclip'} && python {annotation_python_file} scorevariants-deepripe {{input.variants}} {anno_dir} {{input.fasta}} {pybedtools_tmp_path / 'parclip'} {saved_deepripe_models_path} {{threads}} 'parclip'" + + +rule deepRiPe_eclip_hg2: + input: + variants=anno_tmp_dir / (source_variant_file_pattern + "_variants.vcf"), + fasta=fasta_dir / fasta_file_name, + output: + anno_dir / (source_variant_file_pattern + "_variants.eclip_hg2_deepripe.csv.gz"), + threads: lambda wildcards, attempt: n_jobs_deepripe * attempt + shell: + f"mkdir -p {pybedtools_tmp_path / 'hg2'} && python {annotation_python_file} scorevariants-deepripe {{input.variants}} {anno_dir} {{input.fasta}} {pybedtools_tmp_path / 'hg2'} {saved_deepripe_models_path} {{threads}} 'eclip_hg2'" + + +rule deepRiPe_eclip_k5: + input: + variants=anno_tmp_dir / (source_variant_file_pattern + "_variants.vcf"), + fasta=fasta_dir / fasta_file_name, + output: + anno_dir / (source_variant_file_pattern + 
"_variants.eclip_k5_deepripe.csv.gz"), + threads: lambda wildcards, attempt: n_jobs_deepripe * attempt + shell: + f"mkdir -p {pybedtools_tmp_path / 'k5'} && python {annotation_python_file} scorevariants-deepripe {{input.variants}} {anno_dir} {{input.fasta}} {pybedtools_tmp_path / 'k5'} {saved_deepripe_models_path} {{threads}} 'eclip_k5'" rule vep: input: - vcf=anno_tmp_dir / (vcf_pattern + "_stripped.vcf.gz"), + vcf=anno_tmp_dir / (source_variant_file_pattern + "_stripped.vcf.gz"), fasta=fasta_dir / fasta_file_name, output: anno_dir / (vcf_pattern + "_vep_anno.tsv"), @@ -467,9 +533,9 @@ rule extract_with_header: rule strip_chr_name: input: - anno_tmp_dir / (vcf_pattern + "_variants.vcf"), + anno_tmp_dir / (source_variant_file_pattern + "_variants.vcf"), output: - anno_tmp_dir / (vcf_pattern + "_stripped.vcf.gz"), + anno_tmp_dir / (source_variant_file_pattern + "_stripped.vcf.gz"), shell: f"{load_hts} cut -c 4- {{input}} |bgzip > {{output}}" @@ -478,7 +544,7 @@ rule extract_variants: input: bcf_dir / (vcf_pattern + ".bcf"), output: - anno_tmp_dir / (vcf_pattern + "_variants.vcf"), + anno_tmp_dir / (source_variant_file_pattern + "_variants.vcf"), shell: " ".join( [ diff --git a/pipelines/association_testing/association_dataset.snakefile b/pipelines/association_testing/association_dataset.snakefile new file mode 100644 index 00000000..0e63e53f --- /dev/null +++ b/pipelines/association_testing/association_dataset.snakefile @@ -0,0 +1,12 @@ + +rule association_dataset: + input: + config = '{phenotype}/deeprvat/hpopt_config.yaml' + output: + '{phenotype}/deeprvat/association_dataset.pkl' + threads: 4 + shell: + 'deeprvat_associate make-dataset ' + + debug + + '{input.config} ' + '{output}' diff --git a/pipelines/association_testing/burdens.snakefile b/pipelines/association_testing/burdens.snakefile new file mode 100644 index 00000000..550390fa --- /dev/null +++ b/pipelines/association_testing/burdens.snakefile @@ -0,0 +1,74 @@ + +rule link_burdens: + priority: 1 + input: + checkpoints = lambda wildcards: [ + f'{model_path}/repeat_{repeat}/best/bag_{bag}.ckpt' + for repeat in range(n_repeats) for bag in range(n_bags) + ], + dataset = '{phenotype}/deeprvat/association_dataset.pkl', + data_config = '{phenotype}/deeprvat/hpopt_config.yaml', + model_config = model_path / 'config.yaml', + output: + '{phenotype}/deeprvat/burdens/chunk{chunk}.linked' + threads: 8 + shell: + ' && '.join([ + ('deeprvat_associate compute-burdens ' + + debug + + ' --n-chunks '+ str(n_burden_chunks) + ' ' + f'--link-burdens ../../../{phenotypes[0]}/deeprvat/burdens/burdens.zarr ' + '--chunk {wildcards.chunk} ' + '--dataset-file {input.dataset} ' + '{input.data_config} ' + '{input.model_config} ' + '{input.checkpoints} ' + '{wildcards.phenotype}/deeprvat/burdens'), + 'touch {output}' + ]) + +rule compute_burdens: + priority: 10 + input: + reversed = model_path / "reverse_finished.tmp", + checkpoints = lambda wildcards: [ + model_path / f'repeat_{repeat}/best/bag_{bag}.ckpt' + for repeat in range(n_repeats) for bag in range(n_bags) + ], + dataset = '{phenotype}/deeprvat/association_dataset.pkl', + data_config = '{phenotype}/deeprvat/hpopt_config.yaml', + model_config = model_path / 'config.yaml', + output: + '{phenotype}/deeprvat/burdens/chunk{chunk}.finished' + threads: 8 + shell: + ' && '.join([ + ('deeprvat_associate compute-burdens ' + + debug + + ' --n-chunks '+ str(n_burden_chunks) + ' ' + '--chunk {wildcards.chunk} ' + '--dataset-file {input.dataset} ' + '{input.data_config} ' + '{input.model_config} ' + 
'{input.checkpoints} ' + '{wildcards.phenotype}/deeprvat/burdens'), + 'touch {output}' + ]) + +rule reverse_models: + input: + checkpoints = expand(model_path / 'repeat_{repeat}/best/bag_{bag}.ckpt', + bag=range(n_bags), repeat=range(n_repeats)), + model_config = model_path / 'config.yaml', + data_config = Path(phenotypes[0]) / "deeprvat/hpopt_config.yaml", + output: + temp(model_path / "reverse_finished.tmp") + threads: 4 + shell: + " && ".join([ + ("deeprvat_associate reverse-models " + "{input.model_config} " + "{input.data_config} " + "{input.checkpoints}"), + "touch {output}" + ]) \ No newline at end of file diff --git a/pipelines/association_testing/regress_eval.snakefile b/pipelines/association_testing/regress_eval.snakefile new file mode 100644 index 00000000..bcb3f369 --- /dev/null +++ b/pipelines/association_testing/regress_eval.snakefile @@ -0,0 +1,63 @@ + +rule evaluate: + input: + associations = expand('{{phenotype}}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', + repeat=range(n_repeats)), + config = '{phenotype}/deeprvat/hpopt_config.yaml', + output: + "{phenotype}/deeprvat/eval/significant.parquet", + "{phenotype}/deeprvat/eval/all_results.parquet" + threads: 1 + shell: + 'deeprvat_evaluate ' + + debug + + '--use-seed-genes ' + '--n-repeats {n_repeats} ' + '--correction-method FDR ' + '{input.associations} ' + '{input.config} ' + '{wildcards.phenotype}/deeprvat/eval' + +rule all_regression: + input: + expand('{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', + phenotype=phenotypes, type=['deeprvat'], repeat=range(n_repeats)), + +rule combine_regression_chunks: + input: + expand('{{phenotype}}/deeprvat/repeat_{{repeat}}/results/burden_associations_{chunk}.parquet', chunk=range(n_regression_chunks)), + output: + '{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', + threads: 1 + shell: + 'deeprvat_associate combine-regression-results ' + '--model-name repeat_{wildcards.repeat} ' + '{input} ' + '{output}' + +rule regress: + input: + config = "{phenotype}/deeprvat/hpopt_config.yaml", + chunks = lambda wildcards: expand( + ('{{phenotype}}/deeprvat/burdens/chunk{chunk}.' 
+ + ("finished" if wildcards.phenotype == phenotypes[0] else "linked")), + chunk=range(n_burden_chunks) + ), + phenotype_0_chunks = expand( + phenotypes[0] + '/deeprvat/burdens/chunk{chunk}.finished', + chunk=range(n_burden_chunks) + ), + output: + temp('{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations_{chunk}.parquet'), + threads: 2 + shell: + 'deeprvat_associate regress ' + + debug + + '--chunk {wildcards.chunk} ' + '--n-chunks ' + str(n_regression_chunks) + ' ' + '--use-bias ' + '--repeat {wildcards.repeat} ' + + do_scoretest + + '{input.config} ' + '{wildcards.phenotype}/deeprvat/burdens ' #TODO make this w/o repeats + '{wildcards.phenotype}/deeprvat/repeat_{wildcards.repeat}/results' \ No newline at end of file diff --git a/pipelines/association_testing_pretrained.snakefile b/pipelines/association_testing_pretrained.snakefile index d8aac7b3..d7aaa006 100644 --- a/pipelines/association_testing_pretrained.snakefile +++ b/pipelines/association_testing_pretrained.snakefile @@ -5,19 +5,27 @@ configfile: 'config.yaml' debug_flag = config.get('debug', False) phenotypes = config['phenotypes'] phenotypes = list(phenotypes.keys()) if type(phenotypes) == dict else phenotypes +training_phenotypes = config["training"].get("phenotypes", phenotypes) n_burden_chunks = config.get('n_burden_chunks', 1) if not debug_flag else 2 n_regression_chunks = config.get('n_regression_chunks', 40) if not debug_flag else 2 +n_trials = config['hyperparameter_optimization']['n_trials'] n_bags = config['training']['n_bags'] if not debug_flag else 3 n_repeats = config['n_repeats'] debug = '--debug ' if debug_flag else '' do_scoretest = '--do-scoretest ' if config.get('do_scoretest', False) else '' -pretrained_model_path = Path(config.get("pretrained_model_path", "pretrained_models")) +tensor_compression_level = config['training'].get('tensor_compression_level', 1) +model_path = Path(config.get("pretrained_model_path", "pretrained_models")) wildcard_constraints: repeat="\d+", trial="\d+", +include: "training/config.snakefile" +include: "association_testing/association_dataset.snakefile" +include: "association_testing/burdens.snakefile" +include: "association_testing/regress_eval.snakefile" + rule all: input: expand("{phenotype}/deeprvat/eval/significant.parquet", @@ -25,69 +33,6 @@ rule all: expand("{phenotype}/deeprvat/eval/all_results.parquet", phenotype=phenotypes) -rule evaluate: - input: - associations = expand('{{phenotype}}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', - repeat=range(n_repeats)), - config = '{phenotype}/deeprvat/hpopt_config.yaml', - output: - "{phenotype}/deeprvat/eval/significant.parquet", - "{phenotype}/deeprvat/eval/all_results.parquet" - threads: 1 - shell: - 'deeprvat_evaluate ' - + debug + - '--use-seed-genes ' - '--n-repeats {n_repeats} ' - '--correction-method FDR ' - '{input.associations} ' - '{input.config} ' - '{wildcards.phenotype}/deeprvat/eval' - -rule all_regression: - input: - expand('{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', - phenotype=phenotypes, type=['deeprvat'], repeat=range(n_repeats)), - -rule combine_regression_chunks: - input: - expand('{{phenotype}}/deeprvat/repeat_{{repeat}}/results/burden_associations_{chunk}.parquet', chunk=range(n_regression_chunks)), - output: - '{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', - threads: 1 - shell: - 'deeprvat_associate combine-regression-results ' - '--model-name repeat_{wildcards.repeat} ' - '{input} ' - '{output}' - -rule regress: - 
input: - config = "{phenotype}/deeprvat/hpopt_config.yaml", - chunks = lambda wildcards: expand( - ('{{phenotype}}/deeprvat/burdens/chunk{chunk}.' + - ("finished" if wildcards.phenotype == phenotypes[0] else "linked")), - chunk=range(n_burden_chunks) - ), - phenotype_0_chunks = expand( - phenotypes[0] + '/deeprvat/burdens/chunk{chunk}.finished', - chunk=range(n_burden_chunks) - ), - output: - temp('{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations_{chunk}.parquet'), - threads: 2 - shell: - 'deeprvat_associate regress ' - + debug + - '--chunk {wildcards.chunk} ' - '--n-chunks ' + str(n_regression_chunks) + ' ' - '--use-bias ' - '--repeat {wildcards.repeat} ' - + do_scoretest + - '{input.config} ' - '{wildcards.phenotype}/deeprvat/burdens ' #TODO make this w/o repeats - '{wildcards.phenotype}/deeprvat/repeat_{wildcards.repeat}/results' - rule all_burdens: input: [ @@ -97,97 +42,11 @@ rule all_burdens: for c in range(n_burden_chunks) ] -rule link_burdens: - priority: 1 - input: - checkpoints = lambda wildcards: [ - f'{pretrained_model_path}/repeat_{repeat}/best/bag_{bag}.ckpt' - for repeat in range(n_repeats) for bag in range(n_bags) - ], - dataset = '{phenotype}/deeprvat/association_dataset.pkl', - data_config = '{phenotype}/deeprvat/hpopt_config.yaml', - model_config = pretrained_model_path / 'config.yaml', - output: - '{phenotype}/deeprvat/burdens/chunk{chunk}.linked' - threads: 8 - shell: - ' && '.join([ - ('deeprvat_associate compute-burdens ' - + debug + - ' --n-chunks '+ str(n_burden_chunks) + ' ' - f'--link-burdens ../../../{phenotypes[0]}/deeprvat/burdens/burdens.zarr ' - '--chunk {wildcards.chunk} ' - '--dataset-file {input.dataset} ' - '{input.data_config} ' - '{input.model_config} ' - '{input.checkpoints} ' - '{wildcards.phenotype}/deeprvat/burdens'), - 'touch {output}' - ]) - -rule compute_burdens: - priority: 10 - input: - reversed = pretrained_model_path / "reverse_finished.tmp", - checkpoints = lambda wildcards: [ - pretrained_model_path / f'repeat_{repeat}/best/bag_{bag}.ckpt' - for repeat in range(n_repeats) for bag in range(n_bags) - ], - dataset = '{phenotype}/deeprvat/association_dataset.pkl', - data_config = '{phenotype}/deeprvat/hpopt_config.yaml', - model_config = pretrained_model_path / 'config.yaml', - output: - '{phenotype}/deeprvat/burdens/chunk{chunk}.finished' - threads: 8 - shell: - ' && '.join([ - ('deeprvat_associate compute-burdens ' - + debug + - ' --n-chunks '+ str(n_burden_chunks) + ' ' - '--chunk {wildcards.chunk} ' - '--dataset-file {input.dataset} ' - '{input.data_config} ' - '{input.model_config} ' - '{input.checkpoints} ' - '{wildcards.phenotype}/deeprvat/burdens'), - 'touch {output}' - ]) - rule all_association_dataset: input: expand('{phenotype}/deeprvat/association_dataset.pkl', phenotype=phenotypes) -rule association_dataset: - input: - config = '{phenotype}/deeprvat/hpopt_config.yaml' - output: - '{phenotype}/deeprvat/association_dataset.pkl' - threads: 4 - shell: - 'deeprvat_associate make-dataset ' - + debug + - '{input.config} ' - '{output}' - -rule reverse_models: - input: - checkpoints = expand(pretrained_model_path / 'repeat_{repeat}/best/bag_{bag}.ckpt', - bag=range(n_bags), repeat=range(n_repeats)), - model_config = pretrained_model_path / 'config.yaml', - data_config = Path(phenotypes[0]) / "deeprvat/hpopt_config.yaml", - output: - temp(pretrained_model_path / "reverse_finished.tmp") - threads: 4 - shell: - " && ".join([ - ("deeprvat_associate reverse-models " - "{input.model_config} " - "{input.data_config} " - 
"{input.checkpoints}"), - "touch {output}" - ]) - rule all_config: input: seed_genes = expand('{phenotype}/deeprvat/seed_genes.parquet', @@ -196,32 +55,3 @@ rule all_config: phenotype=phenotypes), baseline = expand('{phenotype}/deeprvat/baseline_results.parquet', phenotype=phenotypes), - -rule config: - input: - config = 'config.yaml', - baseline = lambda wildcards: [ - str(Path(r['base']) / wildcards.phenotype / r['type'] / - 'eval/burden_associations_testing.parquet') - for r in config['baseline_results'] - ] - output: - seed_genes = '{phenotype}/deeprvat/seed_genes.parquet', - config = '{phenotype}/deeprvat/hpopt_config.yaml', - baseline = '{phenotype}/deeprvat/baseline_results.parquet', - threads: 1 - params: - baseline_results = lambda wildcards, input: ''.join([ - f'--baseline-results {b} ' - for b in input.baseline - ]) - shell: - ( - 'deeprvat_config update-config ' - '--phenotype {wildcards.phenotype} ' - '{params.baseline_results}' - '--baseline-results-out {output.baseline} ' - '--seed-genes-out {output.seed_genes} ' - '{input.config} ' - '{output.config}' - ) diff --git a/pipelines/config/deeprvat_preprocess_config.yaml b/pipelines/config/deeprvat_preprocess_config.yaml index 8275d8cb..0f1b146c 100644 --- a/pipelines/config/deeprvat_preprocess_config.yaml +++ b/pipelines/config/deeprvat_preprocess_config.yaml @@ -1,21 +1,18 @@ # What chromosomes should be processed included_chromosomes : [21,22] +# The format of the name of the "raw" vcf files +vcf_files_list: vcf_files_list.txt + +# Number of threads to use in the preprocessing script, separate from snakemake threads +preprocess_threads: 16 + # If you need to run a cmd to load bcf and samtools specify it here, see example bcftools_load_cmd : # module load bcftools/1.10.2 && samtools_load_cmd : # module load samtools/1.9 && # Path to where you want to write results and intermediate data working_dir: workdir -# Path to ukbb data -data_dir: data - -# These paths are all relative to the data dir -input_vcf_dir_name : vcf -metadata_dir_name: metadata - -# expected to be found in the data_dir / metadata_dir -pvcf_blocks_file: pvcf_blocks.txt # These paths are all relative to the working dir # Here will the finished preprocessed files end up @@ -30,13 +27,5 @@ sparse_dir_name : sparse # Expected to be found in working_dir/reference_dir reference_fasta_file : GRCh38.primary_assembly.genome.fa -# The format of the name of the "raw" vcf files -vcf_filename_pattern: test_vcf_data_c{chr}_b{block} -# for ukbb data use this pattern: -#vcf_filename_pattern: ukb23156_c{chr}_b{block}_v1 - -# Number of threads to use in the preprocessing script, separate from snakemake threads -preprocess_threads: 16 - # You can specify a different zcat cmd for example gzcat here, default zcat zcat_cmd: \ No newline at end of file diff --git a/pipelines/preprocess.snakefile b/pipelines/preprocess.snakefile deleted file mode 100644 index 264702e8..00000000 --- a/pipelines/preprocess.snakefile +++ /dev/null @@ -1,327 +0,0 @@ -import pandas as pd -from pathlib import Path - - -configfile: "config/deeprvat_preprocess_config.yaml" - -load_samtools = config.get("samtools_load_cmd") or "" -load_bcftools = config.get("bcftools_load_cmd") or "" -zcat_cmd = config.get("zcat_cmd") or "zcat" - -preprocessing_cmd = "deeprvat_preprocess" - -working_dir = Path(config["working_dir"]) -data_dir = Path(config["data_dir"]) -preprocessed_dir = working_dir / config["preprocessed_dir_name"] -vcf_dir = data_dir / config["input_vcf_dir_name"] -metadata_dir = data_dir / 
config["metadata_dir_name"] -reference_dir = working_dir / config["reference_dir_name"] - -preprocess_threads = config["preprocess_threads"] - -fasta_file = reference_dir / config["reference_fasta_file"] -fasta_index_file = reference_dir / f"{config['reference_fasta_file']}.fai" - -norm_dir = working_dir / config["norm_dir_name"] -sparse_dir = norm_dir / config["sparse_dir_name"] -bcf_dir = norm_dir / "bcf" -norm_variants_dir = norm_dir / "variants" - -qc_dir = working_dir / "qc" -qc_indmiss_stats_dir = qc_dir / "indmiss/stats" -qc_indmiss_samples_dir = qc_dir / "indmiss/samples" -qc_indmiss_sites_dir = qc_dir / "indmiss/sites" -qc_varmiss_dir = qc_dir / "varmiss" -qc_hwe_dir = qc_dir / "hwe" -qc_read_depth_dir = qc_dir / "read_depth" -qc_allelic_imbalance_dir = qc_dir / "allelic_imbalance" -qc_duplicate_vars_dir = qc_dir / "duplicate_vars" -qc_filtered_samples_dir = qc_dir / "filtered_samples" - -vcf_filename_pattern = config["vcf_filename_pattern"] -vcf_files = vcf_dir / f"{vcf_filename_pattern}.vcf.gz" - -pvcf_blocks_df = pd.read_csv( - metadata_dir / config["pvcf_blocks_file"], - sep="\t", - header=None, - names=["Index", "Chromosome", "Block", "First position", "Last position"], - dtype={"Chromosome": str}, -).set_index("Index") - -# Filter out which chromosomes to work with -pvcf_blocks_df = pvcf_blocks_df[ - pvcf_blocks_df["Chromosome"].isin([str(c) for c in config["included_chromosomes"]]) -] - -chr_mapping = pd.Series( - [str(x) for x in range(1,23)] + ["X", "Y"],index=[str(x) for x in range(1,25)] -) -inv_chr_mapping = pd.Series( - [str(x) for x in range(1,25)],index=[str(x) for x in range(1,23)] + ["X", "Y"] -) - -pvcf_blocks_df["chr_name"] = chr_mapping.loc[pvcf_blocks_df["Chromosome"].values].values -chromosomes = pvcf_blocks_df["chr_name"] -block = pvcf_blocks_df["Block"] - - -rule all: - input: - preprocessed_dir / "genotypes.h5", - norm_variants_dir / "variants.tsv.gz", - variants=norm_variants_dir / "variants.parquet", - - -rule combine_genotypes: - input: - expand( - preprocessed_dir / "genotypes_chr{chr}.h5", - zip, - chr=chromosomes, - block=block, - ), - resources: mem_mb=15000 - output: - preprocessed_dir / "genotypes.h5", - shell: - f"{preprocessing_cmd} combine-genotypes {{input}} {{output}}" - - -rule create_excluded_samples_dir: - output: - directory(qc_filtered_samples_dir), - shell: - "mkdir -p {output}" - - -rule preprocess: - input: - variants=norm_variants_dir / "variants.tsv.gz", - variants_parquet=norm_variants_dir / "variants.parquet", - samples=norm_dir / "samples_chr.csv", - sparse_tg=expand( - sparse_dir / "chr{chr}" / f"{vcf_filename_pattern}.tsv.gz", - zip, - chr=chromosomes, - block=block, - ), - qc_varmiss=expand( - qc_varmiss_dir / f"{vcf_filename_pattern}.tsv.gz", - zip, - chr=chromosomes, - block=block, - ), - qc_hwe=expand( - qc_hwe_dir / f"{vcf_filename_pattern}.tsv.gz", - zip, - chr=chromosomes, - block=block, - ), - qc_read_depth=expand( - qc_read_depth_dir / "chr{chr}" / f"{vcf_filename_pattern}.tsv.gz", - zip, - chr=chromosomes, - block=block, - ), - qc_allelic_imbalance=expand( - qc_allelic_imbalance_dir / f"{vcf_filename_pattern}.tsv.gz", - zip, - chr=chromosomes, - block=block, - ), - qc_filtered_samples=qc_filtered_samples_dir, - output: - expand(preprocessed_dir / "genotypes_chr{chr}.h5",chr=set(chromosomes)), - resources: mem_mb=15000 - shell: - " ".join( - [ - f"{preprocessing_cmd}", - "process-sparse-gt", - f"--exclude-variants {qc_allelic_imbalance_dir}", - f"--exclude-variants {qc_hwe_dir}", - f"--exclude-variants 
{qc_varmiss_dir}", - f"--exclude-variants {qc_duplicate_vars_dir}", - f"--exclude-calls {qc_read_depth_dir}", - f"--exclude-samples {qc_filtered_samples_dir}", - "--chromosomes ", - ",".join(str(chr) for chr in set(chromosomes)), - f"--threads {preprocess_threads}", - "{input.variants}", - "{input.samples}", - f"{sparse_dir}", - f"{preprocessed_dir / 'genotypes'}", - ] - ) - - -rule all_qc: - input: - expand( - [ - qc_varmiss_dir / f"{vcf_filename_pattern}.tsv.gz", - qc_hwe_dir / f"{vcf_filename_pattern}.tsv.gz", - qc_read_depth_dir / "chr{chr}" / f"{vcf_filename_pattern}.tsv.gz", - qc_allelic_imbalance_dir / f"{vcf_filename_pattern}.tsv.gz", - ], - zip, - chr=chromosomes, - block=block, - ), - - -rule qc_varmiss: - input: - bcf_dir / "{vcf_filename_pattern}.bcf", - output: - qc_varmiss_dir / "{vcf_filename_pattern}.tsv.gz", - resources: - mem_mb=lambda wildcards, attempt: 256 * attempt, - shell: - f'{load_bcftools} bcftools query --format "%CHROM\t%POS\t%REF\t%ALT\n" --include "F_MISSING >= 0.1" {{input}} | gzip > {{output}}' - - -rule qc_hwe: - input: - bcf_dir / "{vcf_filename_pattern}.bcf", - output: - qc_hwe_dir / "{vcf_filename_pattern}.tsv.gz", - resources: - mem_mb=lambda wildcards, attempt: 256 * (attempt + 1), - shell: - f'{load_bcftools} bcftools +fill-tags --output-type u {{input}} -- --tags HWE | bcftools query --format "%CHROM\t%POS\t%REF\t%ALT\n" --include "INFO/HWE <= 1e-15" | gzip > {{output}}' - - -rule qc_read_depth: - input: - bcf_dir / f"{vcf_filename_pattern}.bcf", - output: - qc_read_depth_dir / "chr{chr}" / f"{vcf_filename_pattern}.tsv.gz", - resources: - mem_mb=lambda wildcards, attempt: 256 * attempt, - shell: - f"""{load_bcftools} bcftools query --format '[%CHROM\\t%POS\\t%REF\\t%ALT\\t%SAMPLE\\n]' --include '(GT!="RR" & GT!="mis" & TYPE="snp" & FORMAT/DP < 7) | (GT!="RR" & GT!="mis" & TYPE="indel" & FORMAT/DP < 10)' {{input}} | gzip > {{output}}""" - - -rule qc_allelic_imbalance: - input: - bcf_dir / "{vcf_filename_pattern}.bcf", - output: - qc_allelic_imbalance_dir / "{vcf_filename_pattern}.tsv.gz", - resources: - mem_mb=lambda wildcards, attempt: 256 * attempt, - shell: - f"""{load_bcftools} bcftools query --format '%CHROM\t%POS\t%REF\t%ALT\n' --exclude 'COUNT(GT="het")=0 || (GT="het" & ((TYPE="snp" & (FORMAT/AD[*:1] / FORMAT/AD[*:0]) > 0.15) | (TYPE="indel" & (FORMAT/AD[*:1] / FORMAT/AD[*:0]) > 0.20)))' {{input}} | gzip > {{output}}""" - - -rule all_preprocess: - input: - expand( - [ - bcf_dir / f"{vcf_filename_pattern}.bcf", - sparse_dir / "chr{chr}" / f"{vcf_filename_pattern}.tsv.gz", - norm_variants_dir / f"{vcf_filename_pattern}.tsv.gz", - ], - zip, - chr=chromosomes, - block=block, - ), - norm_variants_dir / "variants_no_id.tsv.gz", - norm_variants_dir / "variants.tsv.gz", - qc_duplicate_vars_dir / "duplicates.tsv", - - -rule normalize: - input: - vcf=vcf_files, - samplefile=norm_dir / "samples_chr.csv", - fasta=fasta_file, - fastaindex=fasta_index_file, - output: - bcf_dir / f"{vcf_filename_pattern}.bcf", - resources: - mem_mb=lambda wildcards, attempt: 16384 * (attempt + 1), - shell: - f"""{load_bcftools} bcftools view --samples-file {{input.samplefile}} --output-type u {{input.vcf}} | bcftools view --include 'COUNT(GT="alt") > 0' --output-type u | bcftools norm -m-both -f {{input.fasta}} --output-type b --output {{output}}""" - - -rule sparsify: - input: - bcf=bcf_dir / f"{vcf_filename_pattern}.bcf", - output: - tsv=sparse_dir / "chr{chr}" / f"{vcf_filename_pattern}.tsv.gz", - resources: - mem_mb=512, - shell: - f"""{load_bcftools} bcftools query 
--format '[%CHROM\t%POS\t%REF\t%ALT\t%SAMPLE\t%GT\n]' --include 'GT!="RR" & GT!="mis"' {{input.bcf}} \ - | sed 's/0[/,|]1/1/; s/1[/,|]0/1/; s/1[/,|]1/2/; s/0[/,|]0/0/' | gzip > {{output.tsv}}""" - - -rule variants: - input: - bcf=bcf_dir / f"{vcf_filename_pattern}.bcf", - output: - norm_variants_dir / f"{vcf_filename_pattern}.tsv.gz", - resources: - mem_mb=512, - shell: - f"{load_bcftools} bcftools query --format '%CHROM\t%POS\t%REF\t%ALT\n' {{input}} | gzip > {{output}}" - - -rule concatenate_variants: - input: - expand( - norm_variants_dir / f"{vcf_filename_pattern}.tsv.gz", - zip, - chr=chromosomes, - block=block, - ), - output: - norm_variants_dir / "variants_no_id.tsv.gz", - resources: - mem_mb=256, - shell: - "{zcat_cmd} {input} | gzip > {output}" - - -rule add_variant_ids: - input: - norm_variants_dir / "variants_no_id.tsv.gz", - output: - variants=norm_variants_dir / "variants.tsv.gz", - duplicates=qc_duplicate_vars_dir / "duplicates.tsv", - resources: - mem_mb=2048, - shell: - f"{preprocessing_cmd} add-variant-ids {{input}} {{output.variants}} {{output.duplicates}}" - - -rule create_parquet_variant_ids: - input: - norm_variants_dir / "variants_no_id.tsv.gz", - output: - variants=norm_variants_dir / "variants.parquet", - duplicates=qc_duplicate_vars_dir / "duplicates.parquet", - resources: - mem_mb=2048, - shell: - f"{preprocessing_cmd} add-variant-ids {{input}} {{output.variants}} {{output.duplicates}}" - - -rule extract_samples: - input: - expand(vcf_files,zip,chr=chromosomes,block=block), - output: - norm_dir / "samples_chr.csv", - shell: - f"{load_bcftools} bcftools query --list-samples {{input}} > {{output}}" - - -rule index_fasta: - input: - fasta=fasta_file, - output: - fasta_index_file, - shell: - f"{load_samtools} samtools faidx {{input.fasta}}" diff --git a/pipelines/preprocess_no_qc.snakefile b/pipelines/preprocess_no_qc.snakefile new file mode 100644 index 00000000..a98c60d6 --- /dev/null +++ b/pipelines/preprocess_no_qc.snakefile @@ -0,0 +1,33 @@ +include: "preprocessing/preprocess.snakefile" + + +rule all: + input: + preprocessed_dir / "genotypes.h5", + norm_variants_dir / "variants.tsv.gz", + variants=norm_variants_dir / "variants.parquet", + + +rule preprocess_no_qc: + input: + variants=norm_variants_dir / "variants.tsv.gz", + variants_parquet=norm_variants_dir / "variants.parquet", + samples=norm_dir / "samples_chr.csv", + sparse_tg=expand(sparse_dir / "{vcf_stem}.tsv.gz", vcf_stem=vcf_stems), + output: + expand(preprocessed_dir / "genotypes_chr{chr}.h5", chr=chromosomes), + shell: + " ".join( + [ + f"{preprocessing_cmd}", + "process-sparse-gt", + f"--exclude-variants {qc_duplicate_vars_dir}", + "--chromosomes ", + ",".join(str(chr) for chr in set(chromosomes)), + f"--threads {preprocess_threads}", + "{input.variants}", + "{input.samples}", + f"{sparse_dir}", + f"{preprocessed_dir / 'genotypes'}", + ] + ) diff --git a/pipelines/preprocess_with_qc.snakefile b/pipelines/preprocess_with_qc.snakefile new file mode 100644 index 00000000..f0d2c465 --- /dev/null +++ b/pipelines/preprocess_with_qc.snakefile @@ -0,0 +1,49 @@ + +include: "preprocessing/preprocess.snakefile" +include: "preprocessing/qc.snakefile" + + +rule all: + input: + preprocessed_dir / "genotypes.h5", + norm_variants_dir / "variants.tsv.gz", + variants=norm_variants_dir / "variants.parquet", + + +rule preprocess_with_qc: + input: + variants=norm_variants_dir / "variants.tsv.gz", + variants_parquet=norm_variants_dir / "variants.parquet", + samples=norm_dir / "samples_chr.csv", + 
sparse_tg=expand(sparse_dir / "{vcf_stem}.tsv.gz", vcf_stem=vcf_stems), + qc_varmiss=expand(qc_varmiss_dir / "{vcf_stem}.tsv.gz", vcf_stem=vcf_stems), + qc_hwe=expand(qc_hwe_dir / "{vcf_stem}.tsv.gz", vcf_stem=vcf_stems), + qc_read_depth=expand( + qc_read_depth_dir / "{vcf_stem}.tsv.gz", vcf_stem=vcf_stems + ), + qc_allelic_imbalance=expand( + qc_allelic_imbalance_dir / "{vcf_stem}.tsv.gz", vcf_stem=vcf_stems + ), + qc_filtered_samples=qc_filtered_samples_dir, + output: + expand(preprocessed_dir / "genotypes_chr{chr}.h5", chr=chromosomes), + shell: + " ".join( + [ + f"{preprocessing_cmd}", + "process-sparse-gt", + f"--exclude-variants {qc_allelic_imbalance_dir}", + f"--exclude-variants {qc_hwe_dir}", + f"--exclude-variants {qc_varmiss_dir}", + f"--exclude-variants {qc_duplicate_vars_dir}", + f"--exclude-calls {qc_read_depth_dir}", + f"--exclude-samples {qc_filtered_samples_dir}", + "--chromosomes ", + ",".join(str(chr) for chr in set(chromosomes)), + f"--threads {preprocess_threads}", + "{input.variants}", + "{input.samples}", + f"{sparse_dir}", + f"{preprocessed_dir / 'genotypes'}", + ] + ) diff --git a/pipelines/preprocessing/preprocess.snakefile b/pipelines/preprocessing/preprocess.snakefile new file mode 100644 index 00000000..bc9e4702 --- /dev/null +++ b/pipelines/preprocessing/preprocess.snakefile @@ -0,0 +1,138 @@ +from pathlib import Path + + +configfile: "config/deeprvat_preprocess_config.yaml" + + +load_samtools = config.get("samtools_load_cmd") or "" +load_bcftools = config.get("bcftools_load_cmd") or "" +zcat_cmd = config.get("zcat_cmd") or "zcat" + +preprocessing_cmd = "deeprvat_preprocess" + +working_dir = Path(config["working_dir"]) +preprocessed_dir = working_dir / config["preprocessed_dir_name"] +reference_dir = working_dir / config["reference_dir_name"] + +preprocess_threads = config["preprocess_threads"] + +fasta_file = reference_dir / config["reference_fasta_file"] +fasta_index_file = reference_dir / f"{config['reference_fasta_file']}.fai" + +norm_dir = working_dir / config["norm_dir_name"] +sparse_dir = norm_dir / config["sparse_dir_name"] +bcf_dir = norm_dir / "bcf" +norm_variants_dir = norm_dir / "variants" + +qc_dir = working_dir / "qc" +qc_indmiss_stats_dir = qc_dir / "indmiss/stats" +qc_indmiss_samples_dir = qc_dir / "indmiss/samples" +qc_indmiss_sites_dir = qc_dir / "indmiss/sites" +qc_varmiss_dir = qc_dir / "varmiss" +qc_hwe_dir = qc_dir / "hwe" +qc_read_depth_dir = qc_dir / "read_depth" +qc_allelic_imbalance_dir = qc_dir / "allelic_imbalance" +qc_duplicate_vars_dir = qc_dir / "duplicate_vars" +qc_filtered_samples_dir = qc_dir / "filtered_samples" + + +with open(config["vcf_files_list"]) as file: + vcf_files = [Path(line.rstrip()) for line in file] + vcf_stems = [vf.stem.replace(".vcf", "") for vf in vcf_files] + + assert len(vcf_stems) == len(vcf_files) + + vcf_look_up = {stem: file for stem, file in zip(vcf_stems, vcf_files)} + +chromosomes = config["included_chromosomes"] + + +rule combine_genotypes: + input: + expand( + preprocessed_dir / "genotypes_chr{chr}.h5", + chr=chromosomes, + ), + output: + preprocessed_dir / "genotypes.h5", + shell: + f"{preprocessing_cmd} combine-genotypes {{input}} {{output}}" + + +rule normalize: + input: + samplefile=norm_dir / "samples_chr.csv", + fasta=fasta_file, + fastaindex=fasta_index_file, + params: + vcf_file=lambda wildcards: vcf_look_up[wildcards.vcf_stem], + output: + bcf_file=bcf_dir / "{vcf_stem}.bcf", + shell: + f"""{load_bcftools} bcftools view --samples-file {{input.samplefile}} --output-type u 
{{params.vcf_file}} | bcftools view --include 'COUNT(GT="alt") > 0' --output-type u | bcftools norm -m-both -f {{input.fasta}} --output-type b --output {{output.bcf_file}}""" + + +rule index_fasta: + input: + fasta=fasta_file, + output: + fasta_index_file, + shell: + f"{load_samtools} samtools faidx {{input.fasta}}" + + +rule sparsify: + input: + bcf=bcf_dir / "{vcf_stem}.bcf", + output: + tsv=sparse_dir / "{vcf_stem}.tsv.gz", + shell: + f"""{load_bcftools} bcftools query --format '[%CHROM\t%POS\t%REF\t%ALT\t%SAMPLE\t%GT\n]' --include 'GT!="RR" & GT!="mis"' {{input.bcf}} \ + | sed 's/0[/,|]1/1/; s/1[/,|]0/1/; s/1[/,|]1/2/; s/0[/,|]0/0/' | gzip > {{output.tsv}}""" + + +rule variants: + input: + bcf=bcf_dir / "{vcf_stem}.bcf", + output: + norm_variants_dir / "{vcf_stem}.tsv.gz", + shell: + f"{load_bcftools} bcftools query --format '%CHROM\t%POS\t%REF\t%ALT\n' {{input}} | gzip > {{output}}" + + +rule concatenate_variants: + input: + expand(norm_variants_dir / "{vcf_stem}.tsv.gz", vcf_stem=vcf_stems), + output: + norm_variants_dir / "variants_no_id.tsv.gz", + shell: + "{zcat_cmd} {input} | gzip > {output}" + + +rule add_variant_ids: + input: + norm_variants_dir / "variants_no_id.tsv.gz", + output: + variants=norm_variants_dir / "variants.tsv.gz", + duplicates=qc_duplicate_vars_dir / "duplicates.tsv", + shell: + f"{preprocessing_cmd} add-variant-ids {{input}} {{output.variants}} {{output.duplicates}}" + + +rule create_parquet_variant_ids: + input: + norm_variants_dir / "variants_no_id.tsv.gz", + output: + variants=norm_variants_dir / "variants.parquet", + duplicates=qc_duplicate_vars_dir / "duplicates.parquet", + shell: + f"{preprocessing_cmd} add-variant-ids {{input}} {{output.variants}} {{output.duplicates}}" + + +rule extract_samples: + input: + vcf_files, + output: + norm_dir / "samples_chr.csv", + shell: + f"{load_bcftools} bcftools query --list-samples {{input}} > {{output}}" diff --git a/pipelines/preprocessing/qc.snakefile b/pipelines/preprocessing/qc.snakefile new file mode 100644 index 00000000..fe7d3cc6 --- /dev/null +++ b/pipelines/preprocessing/qc.snakefile @@ -0,0 +1,43 @@ + + +rule qc_allelic_imbalance: + input: + bcf_dir / "{vcf_stem}.bcf", + output: + qc_allelic_imbalance_dir / "{vcf_stem}.tsv.gz", + shell: + f"""{load_bcftools} bcftools query --format '%CHROM\t%POS\t%REF\t%ALT\n' --exclude 'COUNT(GT="het")=0 || (GT="het" & ((TYPE="snp" & (FORMAT/AD[*:1] / FORMAT/AD[*:0]) > 0.15) | (TYPE="indel" & (FORMAT/AD[*:1] / FORMAT/AD[*:0]) > 0.20)))' {{input}} | gzip > {{output}}""" + + +rule qc_varmiss: + input: + bcf_dir / "{vcf_stem}.bcf", + output: + qc_varmiss_dir / "{vcf_stem}.tsv.gz", + shell: + f'{load_bcftools} bcftools query --format "%CHROM\t%POS\t%REF\t%ALT\n" --include "F_MISSING >= 0.1" {{input}} | gzip > {{output}}' + + +rule qc_hwe: + input: + bcf_dir / "{vcf_stem}.bcf", + output: + qc_hwe_dir / "{vcf_stem}.tsv.gz", + shell: + f'{load_bcftools} bcftools +fill-tags --output-type u {{input}} -- --tags HWE | bcftools query --format "%CHROM\t%POS\t%REF\t%ALT\n" --include "INFO/HWE <= 1e-15" | gzip > {{output}}' + + +rule qc_read_depth: + input: + bcf_dir / "{vcf_stem}.bcf", + output: + qc_read_depth_dir / "{vcf_stem}.tsv.gz", + shell: + f"""{load_bcftools} bcftools query --format '[%CHROM\\t%POS\\t%REF\\t%ALT\\t%SAMPLE\\n]' --include '(GT!="RR" & GT!="mis" & TYPE="snp" & FORMAT/DP < 7) | (GT!="RR" & GT!="mis" & TYPE="indel" & FORMAT/DP < 10)' {{input}} | gzip > {{output}}""" + + +rule create_excluded_samples_dir: + output: + directory(qc_filtered_samples_dir), + shell: + 
"mkdir -p {output}" diff --git a/pipelines/run_training.snakefile b/pipelines/run_training.snakefile new file mode 100644 index 00000000..0e10d79e --- /dev/null +++ b/pipelines/run_training.snakefile @@ -0,0 +1,51 @@ +from pathlib import Path + +configfile: 'config.yaml' + +debug_flag = config.get('debug', False) +phenotypes = config['phenotypes'] +phenotypes = list(phenotypes.keys()) if type(phenotypes) == dict else phenotypes +training_phenotypes = config["training"].get("phenotypes", phenotypes) + +n_burden_chunks = config.get('n_burden_chunks', 1) if not debug_flag else 2 +n_regression_chunks = config.get('n_regression_chunks', 40) if not debug_flag else 2 +n_trials = config['hyperparameter_optimization']['n_trials'] +n_bags = config['training']['n_bags'] if not debug_flag else 3 +n_repeats = config['n_repeats'] +debug = '--debug ' if debug_flag else '' +do_scoretest = '--do-scoretest ' if config.get('do_scoretest', False) else '' +tensor_compression_level = config['training'].get('tensor_compression_level', 1) +model_path = Path("models") +n_parallel_training_jobs = config["training"].get("n_parallel_jobs", 1) + +wildcard_constraints: + repeat="\d+", + trial="\d+", + +include: "training/config.snakefile" +include: "training/training_dataset.snakefile" +include: "training/train.snakefile" + +rule all: + input: + expand( model_path / 'repeat_{repeat}/best/bag_{bag}.ckpt', + bag=range(n_bags), repeat=range(n_repeats)), + model_path / "config.yaml" + +rule all_training_dataset: + input: + input_tensor = expand('{phenotype}/deeprvat/input_tensor.zarr', + phenotype=training_phenotypes, repeat=range(n_repeats)), + covariates = expand('{phenotype}/deeprvat/covariates.zarr', + phenotype=training_phenotypes, repeat=range(n_repeats)), + y = expand('{phenotype}/deeprvat/y.zarr', + phenotype=training_phenotypes, repeat=range(n_repeats)) + +rule all_config: + input: + seed_genes = expand('{phenotype}/deeprvat/seed_genes.parquet', + phenotype=phenotypes), + config = expand('{phenotype}/deeprvat/hpopt_config.yaml', + phenotype=phenotypes), + baseline = expand('{phenotype}/deeprvat/baseline_results.parquet', + phenotype=phenotypes), \ No newline at end of file diff --git a/pipelines/seed_gene_discovery.snakefile b/pipelines/seed_gene_discovery.snakefile index 1003867a..7a93ac26 100644 --- a/pipelines/seed_gene_discovery.snakefile +++ b/pipelines/seed_gene_discovery.snakefile @@ -11,7 +11,9 @@ vtypes = config.get("variant_types", ["plof"]) ttypes = config.get("test_types", ["burden"]) rare_maf = config.get("rare_maf", 0.001) -n_chunks = config.get("n_chunks", 30) if not debug_flag else 2 + +n_chunks_missense = 15 +n_chunks_plof = 4 debug = "--debug " if debug_flag else "" persist_burdens = "--persist-burdens" if config.get("persist_burdens", False) else "" @@ -65,14 +67,14 @@ rule all_regression: ), -rule combine_regression_chunks: +rule combine_regression_chunks_plof: input: train=expand( - "{{phenotype}}/{{vtype}}/{{ttype}}/results/burden_associations_chunk{chunk}.parquet", - chunk=range(n_chunks), + "{{phenotype}}/plof/{{ttype}}/results/burden_associations_chunk{chunk}.parquet", + chunk=range(n_chunks_plof), ), output: - train="{phenotype}/{vtype}/{ttype}/results/burden_associations.parquet", + train="{phenotype}/plof/{ttype}/results/burden_associations.parquet", threads: 1 resources: mem_mb=2048, @@ -85,31 +87,96 @@ rule combine_regression_chunks: ] ) +rule combine_regression_chunks_missense: + input: + train=expand( + 
"{{phenotype}}/missense/{{ttype}}/results/burden_associations_chunk{chunk}.parquet", + chunk=range(n_chunks_missense), + ), + output: + train="{phenotype}/missense/{ttype}/results/burden_associations.parquet", + threads: 1 + resources: + mem_mb=2048, + load=2000, + shell: + " && ".join( + [ + conda_check, + "seed_gene_pipeline combine-results " "{input.train} " "{output.train}", + ] + ) -rule all_regression_results: + +rule all_regression_results_plof: input: expand( - "{phenotype}/{vtype}/{ttype}/results/burden_associations_chunk{chunk}.parquet", + "{phenotype}/plof/{ttype}/results/burden_associations_chunk{chunk}.parquet", phenotype=phenotypes, vtype=vtypes, ttype=ttypes, - chunk=range(n_chunks), + chunk=range(n_chunks_plof), ), +rule all_regression_results_missense: + input: + expand( + "{phenotype}/missense/{ttype}/results/burden_associations_chunk{chunk}.parquet", + phenotype=phenotypes, + vtype=vtypes, + ttype=ttypes, + chunk=range(n_chunks_missense), + ), + +rule regress_plof: + input: + data="{phenotype}/plof/association_dataset_full.pkl", + dataset="{phenotype}/plof/association_dataset_pickled.pkl", + config="{phenotype}/plof/config.yaml", + output: + out_path=temp( + "{phenotype}/plof/{ttype}/results/burden_associations_chunk{chunk}.parquet" + ), + threads: 1 + priority: 30 + resources: + mem_mb = lambda wildcards, attempt: 20000 + 2000 * attempt, + load=8000, + # gpus = 1 + shell: + " && ".join( + [ + conda_check, + ( + "seed_gene_pipeline run-association " + + debug + + " --n-chunks " + + str(n_chunks_plof) + + " " + "--chunk {wildcards.chunk} " + "--dataset-file {input.dataset} " + "--data-file {input.data} " + persist_burdens + " " + " {input.config} " + "plof " + "{wildcards.ttype} " + "{output.out_path}" + ), + ] + ) -rule regress: +rule regress_missense: input: - data="{phenotype}/{vtype}/association_dataset_full.pkl", - dataset="{phenotype}/{vtype}/association_dataset_pickled.pkl", - config="{phenotype}/{vtype}/config.yaml", + data="{phenotype}/missense/association_dataset_full.pkl", + dataset="{phenotype}/missense/association_dataset_pickled.pkl", + config="{phenotype}/missense/config.yaml", output: out_path=temp( - "{phenotype}/{vtype}/{ttype}/results/burden_associations_chunk{chunk}.parquet" + "{phenotype}/missense/{ttype}/results/burden_associations_chunk{chunk}.parquet" ), - threads: 10 + threads: 1 priority: 30 resources: - mem_mb=24000, + mem_mb = lambda wildcards, attempt: 30000 + 6000 * attempt, load=8000, # gpus = 1 shell: @@ -120,13 +187,13 @@ rule regress: "seed_gene_pipeline run-association " + debug + " --n-chunks " - + str(n_chunks) + + str(n_chunks_missense) + " " "--chunk {wildcards.chunk} " "--dataset-file {input.dataset} " "--data-file {input.data} " + persist_burdens + " " " {input.config} " - "{wildcards.vtype} " + "missense " "{wildcards.ttype} " "{output.out_path}" ), @@ -194,6 +261,7 @@ rule config: "seed_gene_pipeline update-config " + "--phenotype {wildcards.phenotype} " + "--variant-type {wildcards.vtype} " + + "--maf-column MAF " + "--rare-maf " + "{params.rare_maf}" + " {input.config} " @@ -201,3 +269,5 @@ rule config: ), ] ) + + diff --git a/pipelines/training/config.snakefile b/pipelines/training/config.snakefile new file mode 100644 index 00000000..3c58a39d --- /dev/null +++ b/pipelines/training/config.snakefile @@ -0,0 +1,29 @@ + +rule config: + input: + config = 'config.yaml', + baseline = lambda wildcards: [ + str(Path(r['base']) / wildcards.phenotype / r['type'] / + 'eval/burden_associations.parquet') + for r in 
config['baseline_results'] + ] + output: + seed_genes = '{phenotype}/deeprvat/seed_genes.parquet', + config = '{phenotype}/deeprvat/hpopt_config.yaml', + baseline = '{phenotype}/deeprvat/baseline_results.parquet', + threads: 1 + params: + baseline_results = lambda wildcards, input: ''.join([ + f'--baseline-results {b} ' + for b in input.baseline + ]) + shell: + ( + 'deeprvat_config update-config ' + '--phenotype {wildcards.phenotype} ' + '{params.baseline_results}' + '--baseline-results-out {output.baseline} ' + '--seed-genes-out {output.seed_genes} ' + '{input.config} ' + '{output.config}' + ) \ No newline at end of file diff --git a/pipelines/training/train.snakefile b/pipelines/training/train.snakefile new file mode 100644 index 00000000..c747fd1f --- /dev/null +++ b/pipelines/training/train.snakefile @@ -0,0 +1,64 @@ + +rule link_config: + input: + model_path / 'repeat_0/config.yaml' + output: + model_path / 'config.yaml' + threads: 1 + shell: + "ln -s repeat_0/config.yaml {output}" + + +rule best_training_run: + input: + expand(model_path / 'repeat_{{repeat}}/trial{trial_number}/config.yaml', + trial_number=range(n_trials)), + output: + checkpoints = expand(model_path / 'repeat_{{repeat}}/best/bag_{bag}.ckpt', + bag=range(n_bags)), + config = model_path / 'repeat_{repeat}/config.yaml' + threads: 1 + shell: + ( + 'deeprvat_train best-training-run ' + + debug + + '{model_path}/repeat_{wildcards.repeat} ' + '{model_path}/repeat_{wildcards.repeat}/best ' + '{model_path}/repeat_{wildcards.repeat}/hyperparameter_optimization.db ' + '{output.config}' + ) + +rule train: + input: + config = expand('{phenotype}/deeprvat/hpopt_config.yaml', + phenotype=training_phenotypes), + input_tensor = expand('{phenotype}/deeprvat/input_tensor.zarr', + phenotype=training_phenotypes), + covariates = expand('{phenotype}/deeprvat/covariates.zarr', + phenotype=training_phenotypes), + y = expand('{phenotype}/deeprvat/y.zarr', + phenotype=training_phenotypes), + output: + expand(model_path / 'repeat_{repeat}/trial{trial_number}/config.yaml', + repeat=range(n_repeats), trial_number=range(n_trials)), + expand(model_path / 'repeat_{repeat}/trial{trial_number}/finished.tmp', + repeat=range(n_repeats), trial_number=range(n_trials)) + params: + phenotypes = " ".join( + [f"--phenotype {p} " + f"{p}/deeprvat/input_tensor.zarr " + f"{p}/deeprvat/covariates.zarr " + f"{p}/deeprvat/y.zarr" + for p in training_phenotypes]) + shell: + f"parallel --jobs {n_parallel_training_jobs} --halt now,fail=1 --results train_repeat{{{{1}}}}_trial{{{{2}}}}/ " + 'deeprvat_train train ' + + debug + + '--trial-id {{2}} ' + "{params.phenotypes} " + 'config.yaml ' + '{model_path}/repeat_{{1}}/trial{{2}} ' + '{model_path}/repeat_{{1}}/hyperparameter_optimization.db "&&" ' + 'touch {model_path}/repeat_{{1}}/trial{{2}}/finished.tmp ' + "::: " + " ".join(map(str, range(n_repeats))) + " " + "::: " + " ".join(map(str, range(n_trials))) diff --git a/pipelines/training/training_dataset.snakefile b/pipelines/training/training_dataset.snakefile new file mode 100644 index 00000000..66903b85 --- /dev/null +++ b/pipelines/training/training_dataset.snakefile @@ -0,0 +1,37 @@ + +rule training_dataset: + input: + config = '{phenotype}/deeprvat/hpopt_config.yaml', + training_dataset = '{phenotype}/deeprvat/training_dataset.pkl' + output: + input_tensor = directory('{phenotype}/deeprvat/input_tensor.zarr'), + covariates = directory('{phenotype}/deeprvat/covariates.zarr'), + y = directory('{phenotype}/deeprvat/y.zarr') + threads: 8 + priority: 50 + shell: + ( + 
'deeprvat_train make-dataset ' + + debug + + '--compression-level ' + str(tensor_compression_level) + ' ' + '--training-dataset-file {input.training_dataset} ' + '{input.config} ' + '{output.input_tensor} ' + '{output.covariates} ' + '{output.y}' + ) + +rule training_dataset_pickle: + input: + '{phenotype}/deeprvat/hpopt_config.yaml' + output: + '{phenotype}/deeprvat/training_dataset.pkl' + threads: 1 + shell: + ( + 'deeprvat_train make-dataset ' + '--pickle-only ' + '--training-dataset-file {output} ' + '{input} ' + 'dummy dummy dummy' + ) \ No newline at end of file diff --git a/pipelines/training_association_testing.snakefile b/pipelines/training_association_testing.snakefile index 3270290a..60384eaf 100644 --- a/pipelines/training_association_testing.snakefile +++ b/pipelines/training_association_testing.snakefile @@ -5,6 +5,7 @@ configfile: 'config.yaml' debug_flag = config.get('debug', False) phenotypes = config['phenotypes'] phenotypes = list(phenotypes.keys()) if type(phenotypes) == dict else phenotypes +training_phenotypes = config["training"].get("phenotypes", phenotypes) n_burden_chunks = config.get('n_burden_chunks', 1) if not debug_flag else 2 n_regression_chunks = config.get('n_regression_chunks', 40) if not debug_flag else 2 @@ -14,11 +15,20 @@ n_repeats = config['n_repeats'] debug = '--debug ' if debug_flag else '' do_scoretest = '--do-scoretest ' if config.get('do_scoretest', False) else '' tensor_compression_level = config['training'].get('tensor_compression_level', 1) +model_path = Path("models") +n_parallel_training_jobs = config["training"].get("n_parallel_jobs", 1) wildcard_constraints: repeat="\d+", trial="\d+", +include: "training/config.snakefile" +include: "training/training_dataset.snakefile" +include: "training/train.snakefile" +include: "association_testing/association_dataset.snakefile" +include: "association_testing/burdens.snakefile" +include: "association_testing/regress_eval.snakefile" + rule all: input: expand("{phenotype}/deeprvat/eval/significant.parquet", @@ -26,69 +36,6 @@ rule all: expand("{phenotype}/deeprvat/eval/all_results.parquet", phenotype=phenotypes) -rule evaluate: - input: - associations = expand('{{phenotype}}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', - repeat=range(n_repeats)), - config = '{phenotype}/deeprvat/hpopt_config.yaml', - output: - "{phenotype}/deeprvat/eval/significant.parquet", - "{phenotype}/deeprvat/eval/all_results.parquet" - threads: 1 - shell: - 'deeprvat_evaluate ' - + debug + - '--use-seed-genes ' - '--n-repeats {n_repeats} ' - '--correction-method FDR ' - '{input.associations} ' - '{input.config} ' - '{wildcards.phenotype}/deeprvat/eval' - -rule all_regression: - input: - expand('{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', - phenotype=phenotypes, type=['deeprvat'], repeat=range(n_repeats)), - -rule combine_regression_chunks: - input: - expand('{{phenotype}}/deeprvat/repeat_{{repeat}}/results/burden_associations_{chunk}.parquet', chunk=range(n_regression_chunks)), - output: - '{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations.parquet', - threads: 1 - shell: - 'deeprvat_associate combine-regression-results ' - '--model-name repeat_{wildcards.repeat} ' - '{input} ' - '{output}' - -rule regress: - input: - config = "{phenotype}/deeprvat/hpopt_config.yaml", - chunks = lambda wildcards: expand( - ('{{phenotype}}/deeprvat/burdens/chunk{chunk}.' 
+ - ("finished" if wildcards.phenotype == phenotypes[0] else "linked")), - chunk=range(n_burden_chunks) - ), - phenotype_0_chunks = expand( - phenotypes[0] + '/deeprvat/burdens/chunk{chunk}.finished', - chunk=range(n_burden_chunks) - ), - output: - temp('{phenotype}/deeprvat/repeat_{repeat}/results/burden_associations_{chunk}.parquet'), - threads: 2 - shell: - 'deeprvat_associate regress ' - + debug + - '--chunk {wildcards.chunk} ' - '--n-chunks ' + str(n_regression_chunks) + ' ' - '--use-bias ' - '--repeat {wildcards.repeat} ' - + do_scoretest + - '{input.config} ' - '{wildcards.phenotype}/deeprvat/burdens ' #TODO make this w/o repeats - '{wildcards.phenotype}/deeprvat/repeat_{wildcards.repeat}/results' - rule all_burdens: input: [ @@ -98,209 +45,25 @@ rule all_burdens: for c in range(n_burden_chunks) ] -rule link_burdens: - priority: 1 - input: - checkpoints = lambda wildcards: [ - f'models/repeat_{repeat}/best/bag_{bag}.ckpt' - for repeat in range(n_repeats) for bag in range(n_bags) - ], - dataset = '{phenotype}/deeprvat/association_dataset.pkl', - data_config = '{phenotype}/deeprvat/hpopt_config.yaml', - model_config = 'models/config.yaml', - output: - '{phenotype}/deeprvat/burdens/chunk{chunk}.linked' - threads: 8 - shell: - ' && '.join([ - ('deeprvat_associate compute-burdens ' - + debug + - ' --n-chunks '+ str(n_burden_chunks) + ' ' - f'--link-burdens ../../../{phenotypes[0]}/deeprvat/burdens/burdens.zarr ' - '--chunk {wildcards.chunk} ' - '--dataset-file {input.dataset} ' - '{input.data_config} ' - '{input.model_config} ' - '{input.checkpoints} ' - '{wildcards.phenotype}/deeprvat/burdens'), - 'touch {output}' - ]) - -rule compute_burdens: - priority: 10 - input: - reversed = "models/reverse_finished.tmp", - checkpoints = lambda wildcards: [ - f'models/repeat_{repeat}/best/bag_{bag}.ckpt' - for repeat in range(n_repeats) for bag in range(n_bags) - ], - dataset = '{phenotype}/deeprvat/association_dataset.pkl', - data_config = '{phenotype}/deeprvat/hpopt_config.yaml', - model_config = 'models/config.yaml', - output: - '{phenotype}/deeprvat/burdens/chunk{chunk}.finished' - threads: 8 - shell: - ' && '.join([ - ('deeprvat_associate compute-burdens ' - + debug + - ' --n-chunks '+ str(n_burden_chunks) + ' ' - '--chunk {wildcards.chunk} ' - '--dataset-file {input.dataset} ' - '{input.data_config} ' - '{input.model_config} ' - '{input.checkpoints} ' - '{wildcards.phenotype}/deeprvat/burdens'), - 'touch {output}' - ]) - rule all_association_dataset: input: expand('{phenotype}/deeprvat/association_dataset.pkl', phenotype=phenotypes) -rule association_dataset: - input: - config = '{phenotype}/deeprvat/hpopt_config.yaml' - output: - '{phenotype}/deeprvat/association_dataset.pkl' - threads: 4 - shell: - 'deeprvat_associate make-dataset ' - + debug + - '{input.config} ' - '{output}' - -rule reverse_models: - input: - checkpoints = expand('models/repeat_{repeat}/best/bag_{bag}.ckpt', - bag=range(n_bags), repeat=range(n_repeats)), - model_config = 'models/config.yaml', - data_config = Path(phenotypes[0]) / "deeprvat/hpopt_config.yaml", - output: - "models/reverse_finished.tmp" - threads: 4 - shell: - " && ".join([ - ("deeprvat_associate reverse-models " - "{input.model_config} " - "{input.data_config} " - "{input.checkpoints}"), - "touch {output}" - ]) - rule all_training: input: - expand('models/repeat_{repeat}/best/bag_{bag}.ckpt', + expand(model_path / 'repeat_{repeat}/best/bag_{bag}.ckpt', bag=range(n_bags), repeat=range(n_repeats)), - "models/config.yaml" - -rule link_config: - input: - 
'models/repeat_0/config.yaml' - output: - "models/config.yaml" - threads: 1 - shell: - "ln -s repeat_0/config.yaml {output}" - - -rule best_training_run: - input: - expand('models/repeat_{{repeat}}/trial{trial_number}/config.yaml', - trial_number=range(n_trials)), - output: - checkpoints = expand('models/repeat_{{repeat}}/best/bag_{bag}.ckpt', - bag=range(n_bags)), - config = 'models/repeat_{repeat}/config.yaml' - threads: 1 - shell: - ( - 'deeprvat_train best-training-run ' - + debug + - 'models/repeat_{wildcards.repeat} ' - 'models/repeat_{wildcards.repeat}/best ' - 'models/repeat_{wildcards.repeat}/hyperparameter_optimization.db ' - '{output.config}' - ) - -rule train: - input: - config = expand('{phenotype}/deeprvat/hpopt_config.yaml', - phenotype=phenotypes), - input_tensor = expand('{phenotype}/deeprvat/input_tensor.zarr', - phenotype=phenotypes), - covariates = expand('{phenotype}/deeprvat/covariates.zarr', - phenotype=phenotypes), - y = expand('{phenotype}/deeprvat/y.zarr', - phenotype=phenotypes), - output: - config = 'models/repeat_{repeat}/trial{trial_number}/config.yaml', - finished = 'models/repeat_{repeat}/trial{trial_number}/finished.tmp' - params: - phenotypes = " ".join( - [f"--phenotype {p} " - f"{p}/deeprvat/input_tensor.zarr " - f"{p}/deeprvat/covariates.zarr " - f"{p}/deeprvat/y.zarr" - for p in phenotypes]) - shell: - ' && '.join([ - 'deeprvat_train train ' - + debug + - '--trial-id {wildcards.trial_number} ' - "{params.phenotypes} " - 'config.yaml ' - 'models/repeat_{wildcards.repeat}/trial{wildcards.trial_number} ' - 'models/repeat_{wildcards.repeat}/hyperparameter_optimization.db', - 'touch {output.finished}' - ]) + model_path / "config.yaml" rule all_training_dataset: input: input_tensor = expand('{phenotype}/deeprvat/input_tensor.zarr', - phenotype=phenotypes, repeat=range(n_repeats)), + phenotype=training_phenotypes, repeat=range(n_repeats)), covariates = expand('{phenotype}/deeprvat/covariates.zarr', - phenotype=phenotypes, repeat=range(n_repeats)), + phenotype=training_phenotypes, repeat=range(n_repeats)), y = expand('{phenotype}/deeprvat/y.zarr', - phenotype=phenotypes, repeat=range(n_repeats)) - -rule training_dataset: - input: - config = '{phenotype}/deeprvat/hpopt_config.yaml', - training_dataset = '{phenotype}/deeprvat/training_dataset.pkl' - output: - input_tensor = directory('{phenotype}/deeprvat/input_tensor.zarr'), - covariates = directory('{phenotype}/deeprvat/covariates.zarr'), - y = directory('{phenotype}/deeprvat/y.zarr') - threads: 8 - priority: 50 - shell: - ( - 'deeprvat_train make-dataset ' - + debug + - '--compression-level ' + str(tensor_compression_level) + ' ' - '--training-dataset-file {input.training_dataset} ' - '{input.config} ' - '{output.input_tensor} ' - '{output.covariates} ' - '{output.y}' - ) - -rule training_dataset_pickle: - input: - '{phenotype}/deeprvat/hpopt_config.yaml' - output: - '{phenotype}/deeprvat/training_dataset.pkl' - threads: 1 - shell: - ( - 'deeprvat_train make-dataset ' - '--pickle-only ' - '--training-dataset-file {output} ' - '{input} ' - 'dummy dummy dummy' - ) + phenotype=training_phenotypes, repeat=range(n_repeats)) rule all_config: input: @@ -310,32 +73,3 @@ rule all_config: phenotype=phenotypes), baseline = expand('{phenotype}/deeprvat/baseline_results.parquet', phenotype=phenotypes), - -rule config: - input: - config = 'config.yaml', - baseline = lambda wildcards: [ - str(Path(r['base']) / wildcards.phenotype / r['type'] / - 'eval/burden_associations_testing.parquet') - for r in 
config['baseline_results'] - ] - output: - seed_genes = '{phenotype}/deeprvat/seed_genes.parquet', - config = '{phenotype}/deeprvat/hpopt_config.yaml', - baseline = '{phenotype}/deeprvat/baseline_results.parquet', - threads: 1 - params: - baseline_results = lambda wildcards, input: ''.join([ - f'--baseline-results {b} ' - for b in input.baseline - ]) - shell: - ( - 'deeprvat_config update-config ' - '--phenotype {wildcards.phenotype} ' - '{params.baseline_results}' - '--baseline-results-out {output.baseline} ' - '--seed-genes-out {output.seed_genes} ' - '{input.config} ' - '{output.config}' - ) diff --git a/tests/deeprvat/test_data/training/Cholesterol/deeprvat/covariates.zarr/.zarray b/tests/deeprvat/test_data/training/Cholesterol/deeprvat/covariates.zarr/.zarray new file mode 100644 index 00000000..38d199e7 --- /dev/null +++ b/tests/deeprvat/test_data/training/Cholesterol/deeprvat/covariates.zarr/.zarray @@ -0,0 +1,22 @@ +{ + "chunks": [ + 1000, + 22 + ], + "compressor": { + "blocksize": 0, + "clevel": 5, + "cname": "lz4", + "id": "blosc", + "shuffle": 1 + }, + "dtype": " 5" + verbose: True + low_memory: False + verbose: True + dataloader_config: + batch_size: 64 + num_workers: 8 + +data: + gt_file: example/genotypes.h5 + variant_file: example/variants.parquet + dataset_config: + min_common_af: + MAF: 0.01 + phenotype_file: example/phenotypes.parquet + y_transformation: quantile_transform + x_phenotypes: + - age + - genetic_sex + - genetic_PC_1 + - genetic_PC_2 + - genetic_PC_3 + - genetic_PC_4 + - genetic_PC_5 + - genetic_PC_6 + - genetic_PC_7 + - genetic_PC_8 + - genetic_PC_9 + - genetic_PC_10 + - genetic_PC_11 + - genetic_PC_12 + - genetic_PC_13 + - genetic_PC_14 + - genetic_PC_15 + - genetic_PC_16 + - genetic_PC_17 + - genetic_PC_18 + - genetic_PC_19 + - genetic_PC_20 + annotation_file: example/annotations.parquet + annotations: + - MAF + - MAF_MB + - CADD_PHRED + - CADD_raw + - sift_score + - polyphen_score + - Consequence_splice_acceptor_variant + - Consequence_splice_donor_variant + - Consequence_stop_gained + - Consequence_frameshift_variant + - Consequence_stop_lost + - Consequence_start_lost + - Consequence_inframe_insertion + - Consequence_inframe_deletion + - Consequence_missense_variant + - Consequence_protein_altering_variant + - Consequence_splice_region_variant + - condel_score + - DeepSEA_PC_1 + - DeepSEA_PC_2 + - DeepSEA_PC_3 + - DeepSEA_PC_4 + - DeepSEA_PC_5 + - DeepSEA_PC_6 + - PrimateAI_score + - AbSplice_DNA + - DeepRipe_plus_QKI_lip_hg2 + - DeepRipe_plus_QKI_clip_k5 + - DeepRipe_plus_KHDRBS1_clip_k5 + - DeepRipe_plus_ELAVL1_parclip + - DeepRipe_plus_TARDBP_parclip + - DeepRipe_plus_HNRNPD_parclip + - DeepRipe_plus_MBNL1_parclip + - DeepRipe_plus_QKI_parclip + - SpliceAI_delta_score + gene_file: example/protein_coding_genes.parquet + use_common_variants: False + use_rare_variants: True + rare_embedding: + type: PaddedAnnotations + config: + annotations: + - MAF_MB + - CADD_raw + - sift_score + - polyphen_score + - Consequence_splice_acceptor_variant + - Consequence_splice_donor_variant + - Consequence_stop_gained + - Consequence_frameshift_variant + - Consequence_stop_lost + - Consequence_start_lost + - Consequence_inframe_insertion + - Consequence_inframe_deletion + - Consequence_missense_variant + - Consequence_protein_altering_variant + - Consequence_splice_region_variant + - condel_score + - DeepSEA_PC_1 + - DeepSEA_PC_2 + - DeepSEA_PC_3 + - DeepSEA_PC_4 + - DeepSEA_PC_5 + - DeepSEA_PC_6 + - PrimateAI_score + - AbSplice_DNA + - DeepRipe_plus_QKI_lip_hg2 + - 
DeepRipe_plus_QKI_clip_k5 + - DeepRipe_plus_KHDRBS1_clip_k5 + - DeepRipe_plus_ELAVL1_parclip + - DeepRipe_plus_TARDBP_parclip + - DeepRipe_plus_HNRNPD_parclip + - DeepRipe_plus_MBNL1_parclip + - DeepRipe_plus_QKI_parclip + - SpliceAI_delta_score + thresholds: + MAF: "MAF < 1e-3" + CADD_PHRED: "CADD_PHRED > 5" + gene_file: example/protein_coding_genes.parquet + verbose: True + verbose: True + dataloader_config: + batch_size: 16 + num_workers: 10 diff --git a/tests/deeprvat/test_data/training/phenotypes.txt b/tests/deeprvat/test_data/training/phenotypes.txt new file mode 100644 index 00000000..693565ee --- /dev/null +++ b/tests/deeprvat/test_data/training/phenotypes.txt @@ -0,0 +1,5 @@ +Platelet_distribution_width +Cholesterol +Platelet_crit +Mean_platelet_thrombocyte_volume +HDL_cholesterol diff --git a/tests/deeprvat/test_train.py b/tests/deeprvat/test_train.py new file mode 100644 index 00000000..7fbcbbfa --- /dev/null +++ b/tests/deeprvat/test_train.py @@ -0,0 +1,211 @@ +import os +import sys +import logging +from deeprvat.data import DenseGTDataset +import yaml +from typing import Dict, Tuple +import pandas as pd +import itertools +from torch.utils.data import DataLoader +from tqdm import tqdm +import copy +import numpy as np +from pathlib import Path +import pytest +import torch +import zarr + +from deeprvat.deeprvat.train import make_dataset_, MultiphenoDataset + +logging.basicConfig( + format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s", + level="INFO", + stream=sys.stdout, +) +logger = logging.getLogger(__name__) + +# TODO: +# 1. Test cache_tensors +# 2. Test edge cases for data +# 3. Maybe fix expected data? +# 4. Test subset in data[p]["samples"] +# 5. Test entire script (loading from CLI) +# 6. Test "indices" element of dataset batches +# 7. Test val dataset +# 8. Test inputting temporary directory explicitly +# 9. 
Different min_variant_counts + +script_dir = Path(__file__).resolve().parent +tests_data_dir = script_dir / "test_data" / "training" +example_data_dir = script_dir.parent / "example" +test_config_file = tests_data_dir / "config.yaml" + +with open(tests_data_dir / "phenotypes.txt", "r") as f: + phenotypes = f.read().strip().split("\n") + +arrays = ("input_tensor", "covariates", "y") + + +def make_multipheno_data(): + data = { + p: { + a: torch.tensor(zarr.open(tests_data_dir / p / "deeprvat" / f"{a}.zarr")[:]) + for a in arrays[1:] + } + for p in phenotypes + } + for p in phenotypes: + data[p]["input_tensor_zarr"] = zarr.open( + tests_data_dir / p / "deeprvat/input_tensor.zarr" + ) + data[p]["input_tensor"] = data[p]["input_tensor_zarr"][:] + data[p]["samples"] = {"train": np.arange(data[p]["y"].shape[0])} + + return data + + +def subset_samples( + data: Dict[str, torch.Tensor], min_variant_count: int = 0 +) -> Tuple[np.ndarray, Dict[str, torch.Tensor]]: + data = copy.deepcopy(data) + + n_variant_mask = np.sum(np.any(data["input_tensor"], axis=(1, 2)), axis=1) >= 1 + nan_mask = ~np.isnan(data["y"].squeeze()) + mask = n_variant_mask & nan_mask + + for a in arrays: + data[a] = data[a][mask] + + return mask, data + + +def multiphenodataset_reference(data): + data = copy.deepcopy(data) + data = {p: {a: data[p][a].squeeze() for a in arrays} for p in phenotypes} + return data + + +def reconstruct_from_batches(dl: DataLoader): + array_lists = {p: {a: [] for a in arrays} for p in phenotypes} + for batch in tqdm(dl): + for p, data in batch.items(): + array_lists[p]["input_tensor"].append( + data["rare_variant_annotations"].numpy() + ) + array_lists[p]["covariates"].append(data["covariates"].numpy()) + array_lists[p]["y"].append(data["y"].numpy()) + + return { + p: {a: np.concatenate(array_lists[p][a]) for a in arrays} for p in phenotypes + } + + +@pytest.fixture +def multipheno_data(): + data = make_multipheno_data() + reference = multiphenodataset_reference(data) + return data, reference + + +@pytest.mark.parametrize( + "cache_tensors, batch_size", + list(itertools.product([False, True], [1, 13, 1024])), +) +def test_multiphenodataset(multipheno_data, cache_tensors: bool, batch_size: int): + data, reference = multipheno_data + dataset = MultiphenoDataset(data, batch_size, cache_tensors=cache_tensors) + dl = DataLoader(dataset, batch_size=None, num_workers=0) + reconstructed = reconstruct_from_batches(dl) + + for p in phenotypes: + for a in arrays: + assert np.allclose(reference[p][a], reconstructed[p][a]) + + +@pytest.mark.parametrize( + "phenotype, min_variant_count", + list(zip(phenotypes, [0, 1, 2])), +) +def test_make_dataset(phenotype: str, min_variant_count: int, tmp_path: Path): + # os.chdir(example_data_dir) + + with open(test_config_file, "r") as f: + config = yaml.safe_load(f) + + # Set phenotype and seed gene files in config + config["training_data"]["dataset_config"]["y_phenotypes"] = [phenotype] + seed_gene_file = str(tests_data_dir / phenotype / "deeprvat" / "seed_genes.parquet") + config["seed_gene_file"] = seed_gene_file + config["training_data"]["dataset_config"]["gene_file"] = seed_gene_file + config["training_data"]["dataset_config"]["rare_embedding"]["config"][ + "gene_file" + ] = seed_gene_file + config_file = tmp_path / "config.yaml" + with open(config_file, "w") as f: + yaml.dump(config, f) + + # This is the function we want to test + input_tensor_out_file = str(tmp_path / "input_tensor.zarr") + covariates_out_file = str(tmp_path / "covariates.zarr") + y_out_file = 
str(tmp_path / "y.zarr") + logger.info("Constructing test dataset") + test_ds = make_dataset_( + False, + False, + 1, + None, + config_file, + input_tensor_out_file, + covariates_out_file, + y_out_file, + ) + + # Load the data written by make_dataset_ + test_data = {} + test_data["input_tensor"] = zarr.load(input_tensor_out_file) + test_data["covariates"] = zarr.load(covariates_out_file) + test_data["y"] = zarr.load(y_out_file) + + # Assert data shapes agree + assert test_data["input_tensor"].shape[0] == test_data["covariates"].shape[0] + assert test_data["input_tensor"].shape[0] == test_data["y"].shape[0] + + # Assert that all kept entries (up to index min_variant_count - 1) have some nonzero values + if min_variant_count > 0: + assert np.all( + np.any( + test_data["input_tensor"][:, :, :min_variant_count, :] != 0.0, + axis=(1, 3), + ) + ) + + # Load data in a single batch as a reference to check against make_dataset_ + logger.info("Constructing reference dataset") + reference_ds = DenseGTDataset( + gt_file=config["training_data"]["gt_file"], + variant_file=config["training_data"]["variant_file"], + split="", + skip_y_na=True, + **config["training_data"]["dataset_config"], + ) + reference_dl = DataLoader( + reference_ds, collate_fn=reference_ds.collate_fn, batch_size=len(reference_ds) + ) + reference_data = next(iter(reference_dl)) + reference_data = { + "input_tensor": reference_data["rare_variant_annotations"].numpy(), + "covariates": reference_data["x_phenotypes"].numpy(), + "y": reference_data["y"].numpy(), + } + + # Subset reference data + mask, reference_subset = subset_samples(reference_data, min_variant_count) + + # Assert that all dropped entries (beyond index min_variant_count - 1) are 0. + assert np.all( + reference_subset["input_tensor"][~mask, :, min_variant_count:, :] == 0.0 + ) + + for a in arrays: + # Compare make_dataset_ output to reference_subset + assert np.array_equal(test_data[a], reference_subset[a]) diff --git a/tests/test_data/preprocessing/add_variant_ids/add_variant_ids_parquet/expected/expected_variants.tsv.gz b/tests/preprocessing/test_data/add_variant_ids/add_variant_ids_parquet/expected/expected_variants.tsv.gz similarity index 100% rename from tests/test_data/preprocessing/add_variant_ids/add_variant_ids_parquet/expected/expected_variants.tsv.gz rename to tests/preprocessing/test_data/add_variant_ids/add_variant_ids_parquet/expected/expected_variants.tsv.gz diff --git a/tests/test_data/preprocessing/add_variant_ids/add_variant_ids_parquet/input/variants_no_id.tsv.gz b/tests/preprocessing/test_data/add_variant_ids/add_variant_ids_parquet/input/variants_no_id.tsv.gz similarity index 100% rename from tests/test_data/preprocessing/add_variant_ids/add_variant_ids_parquet/input/variants_no_id.tsv.gz rename to tests/preprocessing/test_data/add_variant_ids/add_variant_ids_parquet/input/variants_no_id.tsv.gz diff --git a/tests/test_data/preprocessing/add_variant_ids/add_variant_ids_tsv/expected/expected_variants.tsv.gz b/tests/preprocessing/test_data/add_variant_ids/add_variant_ids_tsv/expected/expected_variants.tsv.gz similarity index 100% rename from tests/test_data/preprocessing/add_variant_ids/add_variant_ids_tsv/expected/expected_variants.tsv.gz rename to tests/preprocessing/test_data/add_variant_ids/add_variant_ids_tsv/expected/expected_variants.tsv.gz diff --git a/tests/test_data/preprocessing/add_variant_ids/add_variant_ids_tsv/input/variants_no_id.tsv.gz b/tests/preprocessing/test_data/add_variant_ids/add_variant_ids_tsv/input/variants_no_id.tsv.gz similarity index 100% rename from 
tests/test_data/preprocessing/add_variant_ids/add_variant_ids_tsv/input/variants_no_id.tsv.gz rename to tests/preprocessing/test_data/add_variant_ids/add_variant_ids_tsv/input/variants_no_id.tsv.gz diff --git a/tests/test_data/preprocessing/combine_genotypes/combine_chr1_chr2/expected/expected_data.npz b/tests/preprocessing/test_data/combine_genotypes/combine_chr1_chr2/expected/expected_data.npz similarity index 100% rename from tests/test_data/preprocessing/combine_genotypes/combine_chr1_chr2/expected/expected_data.npz rename to tests/preprocessing/test_data/combine_genotypes/combine_chr1_chr2/expected/expected_data.npz diff --git a/tests/test_data/preprocessing/combine_genotypes/combine_chr1_chr2/input/genotypes_chr1.h5 b/tests/preprocessing/test_data/combine_genotypes/combine_chr1_chr2/input/genotypes_chr1.h5 similarity index 100% rename from tests/test_data/preprocessing/combine_genotypes/combine_chr1_chr2/input/genotypes_chr1.h5 rename to tests/preprocessing/test_data/combine_genotypes/combine_chr1_chr2/input/genotypes_chr1.h5 diff --git a/tests/test_data/preprocessing/combine_genotypes/combine_chr1_chr2/input/genotypes_chr2.h5 b/tests/preprocessing/test_data/combine_genotypes/combine_chr1_chr2/input/genotypes_chr2.h5 similarity index 100% rename from tests/test_data/preprocessing/combine_genotypes/combine_chr1_chr2/input/genotypes_chr2.h5 rename to tests/preprocessing/test_data/combine_genotypes/combine_chr1_chr2/input/genotypes_chr2.h5 diff --git a/tests/preprocessing/test_data/get_file_chr/chr1_sample.tsv b/tests/preprocessing/test_data/get_file_chr/chr1_sample.tsv new file mode 100644 index 00000000..5ab5ad6e --- /dev/null +++ b/tests/preprocessing/test_data/get_file_chr/chr1_sample.tsv @@ -0,0 +1,4 @@ +100103 chr1 T G 100103 1 +100103 chr1 T A 100096 1 +100103 chr1 T A 100103 1 +100103 chr1 T A 100104 1 diff --git a/tests/preprocessing/test_data/get_file_chr/chr2_sample.tsv b/tests/preprocessing/test_data/get_file_chr/chr2_sample.tsv new file mode 100644 index 00000000..e2e99ea8 --- /dev/null +++ b/tests/preprocessing/test_data/get_file_chr/chr2_sample.tsv @@ -0,0 +1,4 @@ +100103 chr2 T G 100103 1 +100103 chr2 T A 100096 1 +100103 chr2 T A 100103 1 +100103 chr2 T A 100104 1 diff --git a/tests/preprocessing/test_data/get_file_chr/no_chr_sample.tsv b/tests/preprocessing/test_data/get_file_chr/no_chr_sample.tsv new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/expected/expected_data.npz b/tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/expected/expected_data.npz similarity index 100% rename from tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/expected/expected_data.npz rename to tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/expected/expected_data.npz diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/calls/chr1/excluded_calls_untidy.tsv.gz b/tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/calls/chr1/excluded_calls_untidy.tsv.gz similarity index 100% rename from tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/calls/chr1/excluded_calls_untidy.tsv.gz rename to 
tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/calls/chr1/excluded_calls_untidy.tsv.gz diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/calls/chr2/excluded_calls_untidy.tsv.gz b/tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/calls/chr2/excluded_calls_untidy.tsv.gz similarity index 100% rename from tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/calls/chr2/excluded_calls_untidy.tsv.gz rename to tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/calls/chr2/excluded_calls_untidy.tsv.gz diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/samples/excluded_samples_untidy.csv b/tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/samples/excluded_samples_untidy.csv similarity index 100% rename from tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/samples/excluded_samples_untidy.csv rename to tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/samples/excluded_samples_untidy.csv diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/variants/input_c1_b1.tsv.gz b/tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/variants/input_c1_b1.tsv.gz similarity index 100% rename from tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/variants/input_c1_b1.tsv.gz rename to tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/qc/variants/input_c1_b1.tsv.gz diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/samples_chr.csv b/tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/samples_chr.csv similarity index 100% rename from tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/samples_chr.csv rename to tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/samples_chr.csv diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/sparse_gt/chr1/input_c1_b1.tsv.gz b/tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/sparse_gt/chr1/input_c1_b1.tsv.gz similarity index 100% rename from tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/sparse_gt/chr1/input_c1_b1.tsv.gz rename to tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/sparse_gt/chr1/input_c1_b1.tsv.gz diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/sparse_gt/chr1/input_c1_b2.tsv.gz 
b/tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/sparse_gt/chr1/input_c1_b2.tsv.gz similarity index 100% rename from tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/sparse_gt/chr1/input_c1_b2.tsv.gz rename to tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/sparse_gt/chr1/input_c1_b2.tsv.gz diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/sparse_gt/chr2/input_c2_b1.tsv.gz b/tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/sparse_gt/chr2/input_c2_b1.tsv.gz similarity index 100% rename from tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/sparse_gt/chr2/input_c2_b1.tsv.gz rename to tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/sparse_gt/chr2/input_c2_b1.tsv.gz diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/variants.tsv.gz b/tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/variants.tsv.gz similarity index 100% rename from tests/test_data/preprocessing/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/variants.tsv.gz rename to tests/preprocessing/test_data/process_and_combine_sparse_gt/filter_calls_variants_samples_minimal_split/input/variants.tsv.gz diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/no_filters_minimal_split/expected/expected_data.npz b/tests/preprocessing/test_data/process_and_combine_sparse_gt/no_filters_minimal_split/expected/expected_data.npz similarity index 100% rename from tests/test_data/preprocessing/process_and_combine_sparse_gt/no_filters_minimal_split/expected/expected_data.npz rename to tests/preprocessing/test_data/process_and_combine_sparse_gt/no_filters_minimal_split/expected/expected_data.npz diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/no_filters_minimal_split/input/samples_chr.csv b/tests/preprocessing/test_data/process_and_combine_sparse_gt/no_filters_minimal_split/input/samples_chr.csv similarity index 100% rename from tests/test_data/preprocessing/process_and_combine_sparse_gt/no_filters_minimal_split/input/samples_chr.csv rename to tests/preprocessing/test_data/process_and_combine_sparse_gt/no_filters_minimal_split/input/samples_chr.csv diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/no_filters_minimal_split/input/sparse_gt/chr1/input_c1_b1.tsv.gz b/tests/preprocessing/test_data/process_and_combine_sparse_gt/no_filters_minimal_split/input/sparse_gt/chr1/input_c1_b1.tsv.gz similarity index 100% rename from tests/test_data/preprocessing/process_and_combine_sparse_gt/no_filters_minimal_split/input/sparse_gt/chr1/input_c1_b1.tsv.gz rename to tests/preprocessing/test_data/process_and_combine_sparse_gt/no_filters_minimal_split/input/sparse_gt/chr1/input_c1_b1.tsv.gz diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/no_filters_minimal_split/input/sparse_gt/chr1/input_c1_b2.tsv.gz b/tests/preprocessing/test_data/process_and_combine_sparse_gt/no_filters_minimal_split/input/sparse_gt/chr1/input_c1_b2.tsv.gz similarity index 100% rename from 
tests/test_data/preprocessing/process_and_combine_sparse_gt/no_filters_minimal_split/input/sparse_gt/chr1/input_c1_b2.tsv.gz rename to tests/preprocessing/test_data/process_and_combine_sparse_gt/no_filters_minimal_split/input/sparse_gt/chr1/input_c1_b2.tsv.gz diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/no_filters_minimal_split/input/sparse_gt/chr2/input_c2_b1.tsv.gz b/tests/preprocessing/test_data/process_and_combine_sparse_gt/no_filters_minimal_split/input/sparse_gt/chr2/input_c2_b1.tsv.gz similarity index 100% rename from tests/test_data/preprocessing/process_and_combine_sparse_gt/no_filters_minimal_split/input/sparse_gt/chr2/input_c2_b1.tsv.gz rename to tests/preprocessing/test_data/process_and_combine_sparse_gt/no_filters_minimal_split/input/sparse_gt/chr2/input_c2_b1.tsv.gz diff --git a/tests/test_data/preprocessing/process_and_combine_sparse_gt/no_filters_minimal_split/input/variants.tsv.gz b/tests/preprocessing/test_data/process_and_combine_sparse_gt/no_filters_minimal_split/input/variants.tsv.gz similarity index 100% rename from tests/test_data/preprocessing/process_and_combine_sparse_gt/no_filters_minimal_split/input/variants.tsv.gz rename to tests/preprocessing/test_data/process_and_combine_sparse_gt/no_filters_minimal_split/input/variants.tsv.gz diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_calls_minimal/expected/expected_data.npz b/tests/preprocessing/test_data/process_sparse_gt/filter_calls_minimal/expected/expected_data.npz similarity index 100% rename from tests/test_data/preprocessing/process_sparse_gt/filter_calls_minimal/expected/expected_data.npz rename to tests/preprocessing/test_data/process_sparse_gt/filter_calls_minimal/expected/expected_data.npz diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_calls_minimal/input/qc/chr1/excluded_calls.tsv b/tests/preprocessing/test_data/process_sparse_gt/filter_calls_minimal/input/qc/chr1/excluded_calls.tsv similarity index 100% rename from tests/test_data/preprocessing/process_sparse_gt/filter_calls_minimal/input/qc/chr1/excluded_calls.tsv rename to tests/preprocessing/test_data/process_sparse_gt/filter_calls_minimal/input/qc/chr1/excluded_calls.tsv diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_calls_minimal/input/samples_chr.csv b/tests/preprocessing/test_data/process_sparse_gt/filter_calls_minimal/input/samples_chr.csv similarity index 100% rename from tests/test_data/preprocessing/process_sparse_gt/filter_calls_minimal/input/samples_chr.csv rename to tests/preprocessing/test_data/process_sparse_gt/filter_calls_minimal/input/samples_chr.csv diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_calls_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz b/tests/preprocessing/test_data/process_sparse_gt/filter_calls_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz similarity index 100% rename from tests/test_data/preprocessing/process_sparse_gt/filter_calls_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz rename to tests/preprocessing/test_data/process_sparse_gt/filter_calls_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_calls_minimal/input/variants.tsv.gz b/tests/preprocessing/test_data/process_sparse_gt/filter_calls_minimal/input/variants.tsv.gz similarity index 100% rename from tests/test_data/preprocessing/process_sparse_gt/filter_calls_minimal/input/variants.tsv.gz rename to 
tests/preprocessing/test_data/process_sparse_gt/filter_calls_minimal/input/variants.tsv.gz
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/expected/expected_data.npz b/tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/expected/expected_data.npz
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/expected/expected_data.npz
rename to tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/expected/expected_data.npz
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/calls/chr1/excluded_calls_untidy.tsv.gz b/tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/calls/chr1/excluded_calls_untidy.tsv.gz
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/calls/chr1/excluded_calls_untidy.tsv.gz
rename to tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/calls/chr1/excluded_calls_untidy.tsv.gz
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/samples/excluded_samples_untidy.csv b/tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/samples/excluded_samples_untidy.csv
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/samples/excluded_samples_untidy.csv
rename to tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/samples/excluded_samples_untidy.csv
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/variants/input_c1_b1.tsv b/tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/variants/input_c1_b1.tsv
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/variants/input_c1_b1.tsv
rename to tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/variants/input_c1_b1.tsv
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/variants/input_c1_b1.tsv.gz b/tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/variants/input_c1_b1.tsv.gz
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/variants/input_c1_b1.tsv.gz
rename to tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/input/qc/variants/input_c1_b1.tsv.gz
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/input/samples_chr.csv b/tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/input/samples_chr.csv
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/input/samples_chr.csv
rename to tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/input/samples_chr.csv
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz b/tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz
rename to tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/input/variants.tsv.gz b/tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/input/variants.tsv.gz
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_calls_vars_samples_minimal/input/variants.tsv.gz
rename to tests/preprocessing/test_data/process_sparse_gt/filter_calls_vars_samples_minimal/input/variants.tsv.gz
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_samples_minimal/expected/expected_data.npz b/tests/preprocessing/test_data/process_sparse_gt/filter_samples_minimal/expected/expected_data.npz
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_samples_minimal/expected/expected_data.npz
rename to tests/preprocessing/test_data/process_sparse_gt/filter_samples_minimal/expected/expected_data.npz
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_samples_minimal/input/qc/excluded_samples.csv b/tests/preprocessing/test_data/process_sparse_gt/filter_samples_minimal/input/qc/excluded_samples.csv
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_samples_minimal/input/qc/excluded_samples.csv
rename to tests/preprocessing/test_data/process_sparse_gt/filter_samples_minimal/input/qc/excluded_samples.csv
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_samples_minimal/input/samples_chr.csv b/tests/preprocessing/test_data/process_sparse_gt/filter_samples_minimal/input/samples_chr.csv
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_samples_minimal/input/samples_chr.csv
rename to tests/preprocessing/test_data/process_sparse_gt/filter_samples_minimal/input/samples_chr.csv
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_samples_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz b/tests/preprocessing/test_data/process_sparse_gt/filter_samples_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_samples_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz
rename to tests/preprocessing/test_data/process_sparse_gt/filter_samples_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_samples_minimal/input/variants.tsv.gz b/tests/preprocessing/test_data/process_sparse_gt/filter_samples_minimal/input/variants.tsv.gz
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_samples_minimal/input/variants.tsv.gz
rename to tests/preprocessing/test_data/process_sparse_gt/filter_samples_minimal/input/variants.tsv.gz
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_variants_minimal/expected/expected_data.npz b/tests/preprocessing/test_data/process_sparse_gt/filter_variants_minimal/expected/expected_data.npz
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_variants_minimal/expected/expected_data.npz
rename to tests/preprocessing/test_data/process_sparse_gt/filter_variants_minimal/expected/expected_data.npz
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_variants_minimal/input/qc/input_c1_b1.tsv.gz b/tests/preprocessing/test_data/process_sparse_gt/filter_variants_minimal/input/qc/input_c1_b1.tsv.gz
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_variants_minimal/input/qc/input_c1_b1.tsv.gz
rename to tests/preprocessing/test_data/process_sparse_gt/filter_variants_minimal/input/qc/input_c1_b1.tsv.gz
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_variants_minimal/input/samples_chr.csv b/tests/preprocessing/test_data/process_sparse_gt/filter_variants_minimal/input/samples_chr.csv
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_variants_minimal/input/samples_chr.csv
rename to tests/preprocessing/test_data/process_sparse_gt/filter_variants_minimal/input/samples_chr.csv
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_variants_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz b/tests/preprocessing/test_data/process_sparse_gt/filter_variants_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_variants_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz
rename to tests/preprocessing/test_data/process_sparse_gt/filter_variants_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz
diff --git a/tests/test_data/preprocessing/process_sparse_gt/filter_variants_minimal/input/variants.tsv.gz b/tests/preprocessing/test_data/process_sparse_gt/filter_variants_minimal/input/variants.tsv.gz
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/filter_variants_minimal/input/variants.tsv.gz
rename to tests/preprocessing/test_data/process_sparse_gt/filter_variants_minimal/input/variants.tsv.gz
diff --git a/tests/test_data/preprocessing/process_sparse_gt/no_filters_minimal/expected/expected_data.npz b/tests/preprocessing/test_data/process_sparse_gt/no_filters_minimal/expected/expected_data.npz
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/no_filters_minimal/expected/expected_data.npz
rename to tests/preprocessing/test_data/process_sparse_gt/no_filters_minimal/expected/expected_data.npz
diff --git a/tests/test_data/preprocessing/process_sparse_gt/no_filters_minimal/input/samples_chr.csv b/tests/preprocessing/test_data/process_sparse_gt/no_filters_minimal/input/samples_chr.csv
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/no_filters_minimal/input/samples_chr.csv
rename to tests/preprocessing/test_data/process_sparse_gt/no_filters_minimal/input/samples_chr.csv
diff --git a/tests/test_data/preprocessing/process_sparse_gt/no_filters_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz b/tests/preprocessing/test_data/process_sparse_gt/no_filters_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/no_filters_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz
rename to tests/preprocessing/test_data/process_sparse_gt/no_filters_minimal/input/sparse_gt/chr1/input_c1_b1.tsv.gz
diff --git a/tests/test_data/preprocessing/process_sparse_gt/no_filters_minimal/input/variants.tsv.gz b/tests/preprocessing/test_data/process_sparse_gt/no_filters_minimal/input/variants.tsv.gz
similarity index 100%
rename from tests/test_data/preprocessing/process_sparse_gt/no_filters_minimal/input/variants.tsv.gz
rename to tests/preprocessing/test_data/process_sparse_gt/no_filters_minimal/input/variants.tsv.gz
diff --git a/tests/test_preprocess.py b/tests/preprocessing/test_preprocess.py
similarity index 99%
rename from tests/test_preprocess.py
rename to tests/preprocessing/test_preprocess.py
index a150f279..314b3409 100644
--- a/tests/test_preprocess.py
+++ b/tests/preprocessing/test_preprocess.py
@@ -8,7 +8,7 @@ import pytest
 
 script_dir = Path(__file__).resolve().parent
 
-tests_data_dir = script_dir / "test_data/preprocessing"
+tests_data_dir = script_dir / "test_data"
 
 
 def load_h5_archive(h5_path):