diff --git a/deeprvat/data/dense_gt.py b/deeprvat/data/dense_gt.py index 16670651..c02521d5 100644 --- a/deeprvat/data/dense_gt.py +++ b/deeprvat/data/dense_gt.py @@ -409,7 +409,18 @@ def transform_data(self): self.phenotype_df[col] = rng.permutation( self.phenotype_df[col].to_numpy() ) - + if len(self.y_phenotypes) > 0: + unique_y_val = self.phenotype_df[self.y_phenotypes[0]].unique() + n_unique_y_val = np.count_nonzero(~np.isnan(unique_y_val)) + logger.info(f"unique y values {unique_y_val}") + logger.info(n_unique_y_val) + else: + n_unique_y_val = 0 + if n_unique_y_val == 2: + logger.warning( + "Not applying y transformation because y only has two values and seems to be binary" + ) + self.y_transformation = None if self.y_transformation is not None: if self.y_transformation == "standardize": logger.debug(" Standardizing target phenotype") @@ -425,6 +436,8 @@ def transform_data(self): ) else: raise ValueError(f"Unknown y_transformation: {self.y_transformation}") + else: + logger.warning("Not transforming phenotype") def setup_annotations( self, diff --git a/deeprvat/deeprvat/associate.py b/deeprvat/deeprvat/associate.py index dc3463f3..a9d29228 100644 --- a/deeprvat/deeprvat/associate.py +++ b/deeprvat/deeprvat/associate.py @@ -496,8 +496,11 @@ def regress_on_gene_scoretest(gene: str, burdens: np.ndarray, model_score): f"gene {gene}, p-value: {pv}, using saddle instead." ) pv = model_score.pv_alt_model(burdens, method="saddle") - - beta = model_score.coef(burdens)["beta"][0, 0] + # beta only for linear models + try: + beta = model_score.coef(burdens)["beta"][0, 0] + except: + beta = None genes_params_pvalues = ([], [], []) genes_params_pvalues[0].append(gene) @@ -576,7 +579,12 @@ def regress_( logger.info(f"X shape: {X.shape}, Y shape: {y.shape}") # compute null_model for score test - model_score = scoretest.ScoretestNoK(y, X) + if len(np.unique(y)) == 2: + logger.info("Fitting binary model since only found two distinct y values") + model_score = scoretest.ScoretestLogit(y, X) + else: + logger.info("Fitting linear model") + model_score = scoretest.ScoretestNoK(y, X) genes_betas_pvals = [ regress_on_gene_scoretest(gene, burdens[mask, i], model_score) for i, gene in tqdm( diff --git a/deeprvat/seed_gene_discovery/config.yaml b/deeprvat/seed_gene_discovery/config.yaml index 444d0635..f4a5eba0 100644 --- a/deeprvat/seed_gene_discovery/config.yaml +++ b/deeprvat/seed_gene_discovery/config.yaml @@ -20,7 +20,20 @@ phenotypes: # - Platelet_crit # - Platelet_distribution_width # - Red_blood_cell_erythrocyte_count - +# - Body_mass_index_BMI +# - Glucose +# - Vitamin_D +# - Albumin +# - Total_protein +# - Cystatin_C +# - Gamma_glutamyltransferase +# - Alkaline_phosphatase +# - Creatinine +# - Whole_body_fat_free_mass +# - Forced_expiratory_volume_in_1_second_FEV1 +# - Glycated_haemoglobin_HbA1c +# - WHR_Body_mass_index_BMI_corrected + variant_types: - missense - plof @@ -42,7 +55,7 @@ test_config: neglect_homozygous: False collapse_method: sum #collapsing method for burde var_weight_function: beta_maf - + min_mac: 10 variant_file: variants.parquet data: @@ -99,3 +112,4 @@ data: num_workers: 10 #batch_size: 20 + diff --git a/deeprvat/seed_gene_discovery/seed_gene_discovery.py b/deeprvat/seed_gene_discovery/seed_gene_discovery.py index 28d3e160..bd9c6781 100644 --- a/deeprvat/seed_gene_discovery/seed_gene_discovery.py +++ b/deeprvat/seed_gene_discovery/seed_gene_discovery.py @@ -7,6 +7,8 @@ import time from pathlib import Path from typing import Any, Dict, List, Optional, Tuple +import copy + import click import numpy as np @@ -18,7 +20,7 @@ from tqdm import tqdm from deeprvat.data import DenseGTDataset -from seak.scoretest import ScoretestNoK +from seak import scoretest logging.basicConfig( format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s", @@ -38,6 +40,14 @@ def replace_in_array(arr, old_val, new_val): return np.where(arr == old_val, new_val, arr) +def get_caf(G): + # get the cumulative allele frequency + ac = G.sum(axis=0) # allele count of each variant + af = ac / (G.shape[0] * 2) # allele frequency of each variant + caf = af.sum() + return caf + + # return mask def save_burdens(GW_list, GW_full_list, split, chunk, out_dir): burdens_path = Path(f"{out_dir}/burdens") @@ -178,7 +188,11 @@ def get_anno( def call_score(GV, null_model_score, pval_dict, test_type): # score test # p-value for the score-test + start_time = time.time() pv = null_model_score.pv_alt_model(GV) + end_time = time.time() + time_diff = end_time - start_time + pval_dict["time"] = time_diff logger.info(f"p-value: {pv}") if pv < 0.0: logger.warning( @@ -195,10 +209,15 @@ def call_score(GV, null_model_score, pval_dict, test_type): if pv < 1e-3 and test_type == "burden": logger.info("Computing regression coefficient") # if gene is quite significant get the regression coefficient + SE - beta = null_model_score.coef(GV) - logger.info(f"Regression coefficient: {beta}") - pval_dict["beta"] = beta["beta"][0, 0] - pval_dict["betaSd"] = np.sqrt(beta["var_beta"][0, 0]) + # only works for quantitative traits + try: + beta = null_model_score.coef(GV) + logger.info(f"Regression coefficient: {beta}") + pval_dict["beta"] = beta["beta"][0, 0] + pval_dict["betaSd"] = np.sqrt(beta["var_beta"][0, 0]) + except: + pval_dict["beta"] = None + pval_dict["betaSd"] = None return pval_dict @@ -207,13 +226,14 @@ def test_gene( G_full: spmatrix, gene: int, grouped_annotations: pd.DataFrame, - dataset: DenseGTDataset, + Y, weight_cols: List[str], - null_model_score: ScoretestNoK, + null_model_score: scoretest.ScoretestNoK, test_config: Dict, var_type, test_type, maf_col, + min_mac, ) -> Dict[str, Any]: # Find variants present in gene # Convert sparse genotype to CSC @@ -232,11 +252,16 @@ def test_gene( # GET expected allele count (EAC) as in Karczewski et al. 2022/Genebass vars_per_sample = np.sum(G, axis=1) samples_with_variant = vars_per_sample[vars_per_sample > 0].shape[0] - EAC = np.sum(vars_per_sample) + if len(np.unique(Y)) == 2: + n_cases = (Y > 0).sum() + else: + n_cases = Y.shape[0] + EAC = get_caf(G) * n_cases pval_dict = {} pval_dict["EAC"] = EAC + pval_dict["n_cases"] = n_cases pval_dict["gene"] = gene pval_dict["pval"] = np.nan pval_dict["EAC_filtered"] = np.nan @@ -247,11 +272,11 @@ def test_gene( pval_dict["time"] = np.nan var_weight_function = test_config.get("var_weight_function", "sift_polyphen") + max_n_markers = test_config.get("max_n_markers", 5000) + # skips genes with more than max_n_markers qualifying variants logger.info(f"Using function {var_weight_function} for variant weighting") - # keep backwards compatibility - ( weights, _, @@ -272,12 +297,12 @@ def test_gene( f"Number of variants after thresholding using threshold {variant_weight_th}: {len(pos)}" ) pval_dict["n_QV"] = len(pos) - - if len(pos) > 0: + pval_dict["markers_after_mac_collapsing"] = len(pos) + if (len(pos) > 0) & (len(pos) < max_n_markers): G_f = G[:, pos] - EAC_filtered = np.sum(np.sum(G_f, axis=1)) + EAC_filtered = EAC = get_caf(G_f) * n_cases pval_dict["EAC_filtered"] = EAC_filtered - + MAC = G_f.sum(axis=0) count = G_f[G_f == 2].shape[0] # confirm that variants we include are rare variants @@ -303,11 +328,28 @@ def test_gene( pval_dict["n_cluster"] = GW.shape[1] ### COLLAPSE kernel if doing burden test - + collapse_ultra_rare = True if test_type == "skat": logger.info("Running Skat test") - GW = GW - + if collapse_ultra_rare: + logger.info(f"Max Collapsing variants with MAC <= {min_mac}") + MAC_mask = MAC <= min_mac + if MAC_mask.sum() > 0: + logger.info(f"Number of collapsed positions: {MAC_mask.sum()}") + GW_collapse = copy.deepcopy(GW) + GW_collapse = GW_collapse[:, MAC_mask].max(axis=1).reshape(-1, 1) + GW = GW[:, ~MAC_mask] + GW = np.hstack((GW_collapse, GW)) + logger.info(f"GW shape {GW.shape}") + else: + logger.info( + f"No ultra rare variants to collapse ({MAC_mask.sum()})" + ) + GW = GW + else: + GW = GW + + pval_dict["markers_after_mac_collapsing"] = GW.shape[1] if test_type == "burden": collapse_method = test_config.get("collapse_method", "binary") logger.info(f"Running burden test with collapsing method {collapse_method}") @@ -335,14 +377,18 @@ def run_association_( ) -> pd.DataFrame: # initialize the null models # ScoretestNoK automatically adds a bias column if not present - null_model_score = ScoretestNoK(Y, X) + if len(np.unique(Y)) == 2: + print("Fitting binary model since only found two distinct y values") + null_model_score = scoretest.ScoretestLogit(Y, X) + else: + null_model_score = scoretest.ScoretestNoK(Y, X) stats = [] GW_list = {} GW_full_list = {} time_list_inner = {} weight_cols = config.get("weight_cols", []) logger.info(f"Testing with this config: {config['test_config']}") - + min_mac = config["test_config"].get("min_mac", 0) # Get column with minor allele frequency annotations = config["data"]["dataset_config"]["annotations"] maf_col = [ @@ -360,13 +406,14 @@ def run_association_( G_full, gene, grouped_annotations, - dataset, + Y, weight_cols, null_model_score, config["test_config"], var_type, test_type, maf_col, + min_mac, ) if persist_burdens: GW_list[gene] = GW @@ -421,7 +468,7 @@ def update_config( simulated_phenotype_file: str, variant_type: Optional[str], rare_maf: Optional[float], - maf_column: Optional[str], + maf_column: str, new_config_file: str, ): with open(old_config_file) as f: @@ -431,6 +478,7 @@ def update_config( config["data"]["dataset_config"][ "sim_phenotype_file" ] = simulated_phenotype_file + logger.info(f"Reading MAF column from column {maf_column}") if phenotype is not None: config["data"]["dataset_config"]["y_phenotypes"] = [phenotype] @@ -645,9 +693,10 @@ def run_association( exploded_annotations = ( dataset.annotation_df.query("id in @all_variants") .explode("gene_ids") + .reset_index() .drop_duplicates() - ) # row can be duplicated if a variant is assigned to a gene multiple times - + .set_index("id") + ) grouped_annotations = exploded_annotations.groupby("gene_ids") gene_ids = pd.read_parquet(dataset.gene_file, columns=["id"])["id"].to_list() gene_ids = list( diff --git a/pipelines/seed_gene_discovery.snakefile b/pipelines/seed_gene_discovery.snakefile index 1003867a..7a93ac26 100644 --- a/pipelines/seed_gene_discovery.snakefile +++ b/pipelines/seed_gene_discovery.snakefile @@ -11,7 +11,9 @@ vtypes = config.get("variant_types", ["plof"]) ttypes = config.get("test_types", ["burden"]) rare_maf = config.get("rare_maf", 0.001) -n_chunks = config.get("n_chunks", 30) if not debug_flag else 2 + +n_chunks_missense = 15 +n_chunks_plof = 4 debug = "--debug " if debug_flag else "" persist_burdens = "--persist-burdens" if config.get("persist_burdens", False) else "" @@ -65,14 +67,14 @@ rule all_regression: ), -rule combine_regression_chunks: +rule combine_regression_chunks_plof: input: train=expand( - "{{phenotype}}/{{vtype}}/{{ttype}}/results/burden_associations_chunk{chunk}.parquet", - chunk=range(n_chunks), + "{{phenotype}}/plof/{{ttype}}/results/burden_associations_chunk{chunk}.parquet", + chunk=range(n_chunks_plof), ), output: - train="{phenotype}/{vtype}/{ttype}/results/burden_associations.parquet", + train="{phenotype}/plof/{ttype}/results/burden_associations.parquet", threads: 1 resources: mem_mb=2048, @@ -85,31 +87,96 @@ rule combine_regression_chunks: ] ) +rule combine_regression_chunks_missense: + input: + train=expand( + "{{phenotype}}/missense/{{ttype}}/results/burden_associations_chunk{chunk}.parquet", + chunk=range(n_chunks_missense), + ), + output: + train="{phenotype}/missense/{ttype}/results/burden_associations.parquet", + threads: 1 + resources: + mem_mb=2048, + load=2000, + shell: + " && ".join( + [ + conda_check, + "seed_gene_pipeline combine-results " "{input.train} " "{output.train}", + ] + ) -rule all_regression_results: + +rule all_regression_results_plof: input: expand( - "{phenotype}/{vtype}/{ttype}/results/burden_associations_chunk{chunk}.parquet", + "{phenotype}/plof/{ttype}/results/burden_associations_chunk{chunk}.parquet", phenotype=phenotypes, vtype=vtypes, ttype=ttypes, - chunk=range(n_chunks), + chunk=range(n_chunks_plof), ), +rule all_regression_results_missense: + input: + expand( + "{phenotype}/missense/{ttype}/results/burden_associations_chunk{chunk}.parquet", + phenotype=phenotypes, + vtype=vtypes, + ttype=ttypes, + chunk=range(n_chunks_missense), + ), + +rule regress_plof: + input: + data="{phenotype}/plof/association_dataset_full.pkl", + dataset="{phenotype}/plof/association_dataset_pickled.pkl", + config="{phenotype}/plof/config.yaml", + output: + out_path=temp( + "{phenotype}/plof/{ttype}/results/burden_associations_chunk{chunk}.parquet" + ), + threads: 1 + priority: 30 + resources: + mem_mb = lambda wildcards, attempt: 20000 + 2000 * attempt, + load=8000, + # gpus = 1 + shell: + " && ".join( + [ + conda_check, + ( + "seed_gene_pipeline run-association " + + debug + + " --n-chunks " + + str(n_chunks_plof) + + " " + "--chunk {wildcards.chunk} " + "--dataset-file {input.dataset} " + "--data-file {input.data} " + persist_burdens + " " + " {input.config} " + "plof " + "{wildcards.ttype} " + "{output.out_path}" + ), + ] + ) -rule regress: +rule regress_missense: input: - data="{phenotype}/{vtype}/association_dataset_full.pkl", - dataset="{phenotype}/{vtype}/association_dataset_pickled.pkl", - config="{phenotype}/{vtype}/config.yaml", + data="{phenotype}/missense/association_dataset_full.pkl", + dataset="{phenotype}/missense/association_dataset_pickled.pkl", + config="{phenotype}/missense/config.yaml", output: out_path=temp( - "{phenotype}/{vtype}/{ttype}/results/burden_associations_chunk{chunk}.parquet" + "{phenotype}/missense/{ttype}/results/burden_associations_chunk{chunk}.parquet" ), - threads: 10 + threads: 1 priority: 30 resources: - mem_mb=24000, + mem_mb = lambda wildcards, attempt: 30000 + 6000 * attempt, load=8000, # gpus = 1 shell: @@ -120,13 +187,13 @@ rule regress: "seed_gene_pipeline run-association " + debug + " --n-chunks " - + str(n_chunks) + + str(n_chunks_missense) + " " "--chunk {wildcards.chunk} " "--dataset-file {input.dataset} " "--data-file {input.data} " + persist_burdens + " " " {input.config} " - "{wildcards.vtype} " + "missense " "{wildcards.ttype} " "{output.out_path}" ), @@ -194,6 +261,7 @@ rule config: "seed_gene_pipeline update-config " + "--phenotype {wildcards.phenotype} " + "--variant-type {wildcards.vtype} " + + "--maf-column MAF " + "--rare-maf " + "{params.rare_maf}" + " {input.config} " @@ -201,3 +269,5 @@ rule config: ), ] ) + +