From 2a8e99ff47e69f52d7f134c8d52b76c403a52c7d Mon Sep 17 00:00:00 2001 From: "Daniel Y. Zhu" Date: Wed, 21 Feb 2024 14:39:52 -0500 Subject: [PATCH] Updates to CCI modeling downstream funcs --- spateo/tools/CCI_effects_modeling/MuSIC.py | 23 +++-- .../CCI_effects_modeling/MuSIC_downstream.py | 90 ++++++++++++++----- 2 files changed, 87 insertions(+), 26 deletions(-) diff --git a/spateo/tools/CCI_effects_modeling/MuSIC.py b/spateo/tools/CCI_effects_modeling/MuSIC.py index a12edffc..4e96a410 100644 --- a/spateo/tools/CCI_effects_modeling/MuSIC.py +++ b/spateo/tools/CCI_effects_modeling/MuSIC.py @@ -3519,14 +3519,8 @@ def predict( # Subtract 1 from predictions for predictions to account for the pseudocount from model setup: y_pred -= 1 - # if self.mod_type != "downstream": - # y_pred[y_pred < 1] = 0.0 - # else: y_pred[y_pred < 0] = 0.0 - # thresh = 1.01 if self.normalize else 0 - # y_pred[y_pred <= thresh] = 0.0 - y_pred = pd.DataFrame(y_pred, index=self.sample_names, columns=[target]) all_y_pred = pd.concat([all_y_pred, y_pred], axis=1) return all_y_pred @@ -3785,6 +3779,23 @@ def return_outputs( # object, save back to file path: all_outputs = pd.concat([betas, standard_errors], axis=1) all_outputs.to_csv(os.path.join(parent_dir, file)) + else: + # Same processing as for subsampling, but without the subsampling: + if self.mod_type in ["receptor", "ligand", "downstream"]: + mask_matrix = (self.adata[:, target].X != 0).toarray().astype(int) + betas *= mask_matrix + standard_errors *= mask_matrix + mask_df = (self.X_df != 0).astype(int) + mask_df = mask_df.loc[:, [g for g in mask_df.columns if g in feat_sub]] + for col in betas.columns: + if col.replace("b_", "") not in mask_df.columns: + mask_df[col] = 0 + # Make sure the columns are in the same order: + betas_columns = [col.replace("b_", "") for col in betas.columns] + mask_df = mask_df.reindex(columns=betas_columns) + mask_matrix = mask_df.values + betas *= mask_matrix + standard_errors *= mask_matrix # Save coefficients and standard errors to dictionary: all_coeffs[target] = betas diff --git a/spateo/tools/CCI_effects_modeling/MuSIC_downstream.py b/spateo/tools/CCI_effects_modeling/MuSIC_downstream.py index b3d21cac..1e4de531 100644 --- a/spateo/tools/CCI_effects_modeling/MuSIC_downstream.py +++ b/spateo/tools/CCI_effects_modeling/MuSIC_downstream.py @@ -7035,6 +7035,7 @@ def deg_effect_heatmap( self, target_subset: Optional[List[str]] = None, target_type: Literal["ligand", "receptor", "target_gene", "tf_target"] = "target_gene", + to_plot: Literal["proportion", "specificity"] = "proportion", interaction_subset: Optional[List[str]] = None, fontsize: Union[None, int] = None, figsize: Union[None, Tuple[float, float]] = None, @@ -7042,6 +7043,7 @@ def deg_effect_heatmap( lower_proportion_threshold: float = 0.1, order_interactions: bool = False, order_targets: bool = False, + remove_all_zero_rows_and_cols: bool = False, save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show", save_kwargs: Optional[dict] = {}, save_df: bool = False, @@ -7057,6 +7059,9 @@ def deg_effect_heatmap( visualize, e.g. ["Tubb1a", "Tubb1b"]. If not given, will default to all targets. target_type: Type of target gene to visualize. Must be one of "ligand", "receptor", or "target_gene". Defaults to "target_gene". Used to specify where to search for the target genes to process. + to_plot: Two options, "proportion" or "specificity": for proportion, plot the proportion of cells + expressing the target that are affected by each interaction. For specificity, take the proportion of + cells affected by each interaction for which the interaction is predicted to affect a specific target. interaction_subset: Optional, can be used to specify subset of interactions (transcription factors, L:R pairs, etc.) to visualize, e.g. ["Sox2", "Irx3"]. If not given, will default to all TFs, L:R pairs, etc. @@ -7068,6 +7073,7 @@ def deg_effect_heatmap( order_interactions: Whether to hierarchically sort the y-axis/interactions (transcription factors, L:R pairs, etc.). order_targets: Whether to hierarchically sort the x-axis/targets (ligands, receptors, target genes) + remove_all_zero_rows_and_cols: Whether to remove all-zero rows and columns from the heatmap. save_show_or_return: Whether to save, show or return the figure. If "both", it will save and plot the figure at the same time. If "all", the figure will be saved, displayed and the associated axis and other object will be return. @@ -7078,9 +7084,17 @@ def deg_effect_heatmap( keys according to your needs. save_df: Set True to save the metric dataframe in the end """ + if save_df: + output_folder = os.path.join(os.path.dirname(self.output_path), "analyses") + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + base_name = os.path.basename(self.adata_path) + adata_id = os.path.splitext(base_name)[0] + if order_interactions or order_targets: - from scipy.cluster.hierarchy import dendrogram, leaves_list, linkage - from scipy.spatial.distance import pdist, squareform + from scipy.cluster.hierarchy import leaves_list, linkage + from scipy.spatial.distance import pdist if fontsize is None: fontsize = rcParams.get("font.size") @@ -7135,19 +7149,23 @@ def deg_effect_heatmap( # Check whether targets are specified to be ligands, receptors, or target genes: if target_type == "ligand": all_coeffs = self.downstream_model_ligand_coeffs + dm = self.downstream_model_ligand_design_matrix all_feature_names = [ f.replace("regulator_", "") for f in self.downstream_model_ligand_design_matrix.columns ] elif target_type == "receptor": all_coeffs = self.downstream_model_receptor_coeffs + dm = self.downstream_model_receptor_design_matrix all_feature_names = [ f.replace("regulator_", "") for f in self.downstream_model_receptor_design_matrix.columns ] elif target_type == "tf_target": all_coeffs = self.downstream_model_target_coeffs + dm = self.downstream_model_target_design_matrix all_feature_names = [f for f in self.downstream_model_target_design_matrix.columns] elif target_type == "target_gene": all_coeffs = self.coeffs + dm = self.design_matrix all_feature_names = [f for f in self.design_matrix.columns] else: raise ValueError( @@ -7164,7 +7182,7 @@ def deg_effect_heatmap( os.makedirs(figure_folder) fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize) - all_proportions = pd.DataFrame() + all_plot_values = pd.DataFrame() for target in all_coeffs.keys(): all_coeffs_target = all_coeffs[target] @@ -7177,34 +7195,56 @@ def deg_effect_heatmap( all_feature_names = [f for f in all_feature_names if f in all_coeffs_target.columns] all_coeffs_target = all_coeffs_target[all_feature_names] - nz_cells = np.where(self.adata[:, target].X.toarray() > 0)[0] - if len(nz_cells) > 0: - proportions = (all_coeffs_target.iloc[nz_cells] != 0).mean() - all_proportions[target] = proportions + if to_plot == "proportion": + nz_cells = np.where(self.adata[:, target].X.toarray() > 0)[0] + print(len(nz_cells)) + for interaction in all_feature_names: + print(target) + print(interaction) + proportions = (all_coeffs_target.iloc[nz_cells][interaction] != 0).mean() + print((all_coeffs_target.iloc[nz_cells][interaction] != 0).sum()) + all_plot_values.loc[interaction, target] = proportions + elif to_plot == "specificity": + for interaction in all_feature_names: + if target_type in ["ligand", "receptor", "tf_target"]: + all_cells_affected = dm[dm[f"regulator_{interaction}"] > 0] + else: + all_cells_affected = dm[dm[interaction] > 0] + specificity = (all_coeffs_target.loc[all_cells_affected.index, interaction] != 0).mean() + all_plot_values.loc[interaction, target] = specificity + all_plot_values.index = [replace_col_with_collagens(f) for f in all_plot_values.index] + all_plot_values.index = [replace_hla_with_hlas(f) for f in all_plot_values.index] if order_interactions: - interaction_dist_mat = pdist(all_proportions.values, metric="euclidean") + interaction_dist_mat = pdist(all_plot_values.values, metric="euclidean") interaction_linkage = linkage(interaction_dist_mat, method="ward") interaction_order = leaves_list(interaction_linkage) - all_proportions = all_proportions.iloc[interaction_order] + all_proportions = all_plot_values.iloc[interaction_order] if order_targets: - target_dist_mat = pdist(all_proportions.T.values, metric="euclidean") + target_dist_mat = pdist(all_plot_values.T.values, metric="euclidean") target_linkage = linkage(target_dist_mat, method="ward") target_order = leaves_list(target_linkage) - all_proportions = all_proportions.T.iloc[target_order].T + all_plot_values = all_plot_values.T.iloc[target_order].T + + if remove_all_zero_rows_and_cols: + # Keep rows/columns with any values above the threshold + all_plot_values = all_plot_values.loc[ + (all_plot_values > lower_proportion_threshold).any(axis=1), + (all_plot_values > lower_proportion_threshold).any(axis=0), + ] - thickness = 0.3 * figsize[0] / 10 - mask = np.abs(all_proportions) < lower_proportion_threshold + thickness = 0.5 * figsize[0] / 10 + mask = np.abs(all_plot_values) < lower_proportion_threshold m = sns.heatmap( - all_proportions, + all_plot_values, square=True, linecolor="grey", linewidths=thickness, - cbar_kws={"label": "Proportion", "location": "top", "pad": 0.05}, + cbar_kws={"label": to_plot.title(), "location": "top", "pad": 0.05}, cmap=cmap, vmin=0, - vmax=all_proportions.max().max(), + vmax=all_plot_values.max().max(), mask=mask, ax=ax, ) @@ -7216,9 +7256,9 @@ def deg_effect_heatmap( # Adjust colorbar label font size cbar = m.collections[0].colorbar - cbar.set_label("Proportion", fontsize=fontsize * 2, labelpad=10) + cbar.set_label(to_plot.title(), fontsize=fontsize * 1.5, labelpad=10) # Adjust colorbar tick font size - cbar.ax.tick_params(labelsize=fontsize * 2.0) + cbar.ax.tick_params(labelsize=fontsize * 1.5) cbar.ax.set_aspect(0.02) if target_type == "ligand": @@ -7233,16 +7273,26 @@ def deg_effect_heatmap( ax.set_ylabel(y_label, fontsize=fontsize * 2) ax.tick_params(axis="x", labelsize=fontsize, rotation=90) ax.tick_params(axis="y", labelsize=fontsize) - ax.set_title(f"Proportion of target-expressing cells affected by each {id}", fontsize=fontsize * 2.25, pad=20) + title_fontsize = fontsize * 2.25 if figsize[0] > 20 else fontsize * 2 + title = ( + f"Proportion of target-expressing cells \naffected by each {id}" + if to_plot == "proportion" + else f"Specificity of each {id}" + ) + ax.set_title(title, fontsize=title_fontsize, pad=20) prefix = "heatmap" if save_df: - all_proportions.to_csv( + all_plot_values.to_csv( os.path.join( output_folder, f"{prefix}_{adata_id}_proportion_affected_by_interaction.csv", ) ) + self.logger.info( + f"Saving {to_plot} dataframe to " + f"{os.path.join(output_folder,f'{prefix}_{adata_id}_proportion_affected_by_interaction.csv')}" + ) if save_show_or_return in ["save", "both", "all"]: save_kwargs["ext"] = "png"