Skip to content

Commit

Permalink
Updates to CCI modeling downstream funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
dzhu8 committed Feb 21, 2024
1 parent e99379f commit 2a8e99f
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 26 deletions.
23 changes: 17 additions & 6 deletions spateo/tools/CCI_effects_modeling/MuSIC.py
Original file line number Diff line number Diff line change
Expand Up @@ -3519,14 +3519,8 @@ def predict(

# Subtract 1 from predictions for predictions to account for the pseudocount from model setup:
y_pred -= 1
# if self.mod_type != "downstream":
# y_pred[y_pred < 1] = 0.0
# else:
y_pred[y_pred < 0] = 0.0

# thresh = 1.01 if self.normalize else 0
# y_pred[y_pred <= thresh] = 0.0

y_pred = pd.DataFrame(y_pred, index=self.sample_names, columns=[target])
all_y_pred = pd.concat([all_y_pred, y_pred], axis=1)
return all_y_pred
Expand Down Expand Up @@ -3785,6 +3779,23 @@ def return_outputs(
# object, save back to file path:
all_outputs = pd.concat([betas, standard_errors], axis=1)
all_outputs.to_csv(os.path.join(parent_dir, file))
else:
# Same processing as for subsampling, but without the subsampling:
if self.mod_type in ["receptor", "ligand", "downstream"]:
mask_matrix = (self.adata[:, target].X != 0).toarray().astype(int)
betas *= mask_matrix
standard_errors *= mask_matrix
mask_df = (self.X_df != 0).astype(int)
mask_df = mask_df.loc[:, [g for g in mask_df.columns if g in feat_sub]]
for col in betas.columns:
if col.replace("b_", "") not in mask_df.columns:
mask_df[col] = 0
# Make sure the columns are in the same order:
betas_columns = [col.replace("b_", "") for col in betas.columns]
mask_df = mask_df.reindex(columns=betas_columns)
mask_matrix = mask_df.values
betas *= mask_matrix
standard_errors *= mask_matrix

# Save coefficients and standard errors to dictionary:
all_coeffs[target] = betas
Expand Down
90 changes: 70 additions & 20 deletions spateo/tools/CCI_effects_modeling/MuSIC_downstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -7035,13 +7035,15 @@ def deg_effect_heatmap(
self,
target_subset: Optional[List[str]] = None,
target_type: Literal["ligand", "receptor", "target_gene", "tf_target"] = "target_gene",
to_plot: Literal["proportion", "specificity"] = "proportion",
interaction_subset: Optional[List[str]] = None,
fontsize: Union[None, int] = None,
figsize: Union[None, Tuple[float, float]] = None,
cmap: str = "magma",
lower_proportion_threshold: float = 0.1,
order_interactions: bool = False,
order_targets: bool = False,
remove_all_zero_rows_and_cols: bool = False,
save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
save_kwargs: Optional[dict] = {},
save_df: bool = False,
Expand All @@ -7057,6 +7059,9 @@ def deg_effect_heatmap(
visualize, e.g. ["Tubb1a", "Tubb1b"]. If not given, will default to all targets.
target_type: Type of target gene to visualize. Must be one of "ligand", "receptor", or "target_gene".
Defaults to "target_gene". Used to specify where to search for the target genes to process.
to_plot: Two options, "proportion" or "specificity": for proportion, plot the proportion of cells
expressing the target that are affected by each interaction. For specificity, take the proportion of
cells affected by each interaction for which the interaction is predicted to affect a specific target.
interaction_subset: Optional, can be used to specify subset of interactions (transcription factors,
L:R pairs, etc.) to visualize, e.g. ["Sox2", "Irx3"]. If not given, will default to all TFs,
L:R pairs, etc.
Expand All @@ -7068,6 +7073,7 @@ def deg_effect_heatmap(
order_interactions: Whether to hierarchically sort the y-axis/interactions (transcription factors,
L:R pairs, etc.).
order_targets: Whether to hierarchically sort the x-axis/targets (ligands, receptors, target genes)
remove_all_zero_rows_and_cols: Whether to remove all-zero rows and columns from the heatmap.
save_show_or_return: Whether to save, show or return the figure.
If "both", it will save and plot the figure at the same time. If "all", the figure will be saved,
displayed and the associated axis and other object will be return.
Expand All @@ -7078,9 +7084,17 @@ def deg_effect_heatmap(
keys according to your needs.
save_df: Set True to save the metric dataframe in the end
"""
if save_df:
output_folder = os.path.join(os.path.dirname(self.output_path), "analyses")
if not os.path.exists(output_folder):
os.makedirs(output_folder)

base_name = os.path.basename(self.adata_path)
adata_id = os.path.splitext(base_name)[0]

if order_interactions or order_targets:
from scipy.cluster.hierarchy import dendrogram, leaves_list, linkage
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import leaves_list, linkage
from scipy.spatial.distance import pdist

if fontsize is None:
fontsize = rcParams.get("font.size")
Expand Down Expand Up @@ -7135,19 +7149,23 @@ def deg_effect_heatmap(
# Check whether targets are specified to be ligands, receptors, or target genes:
if target_type == "ligand":
all_coeffs = self.downstream_model_ligand_coeffs
dm = self.downstream_model_ligand_design_matrix
all_feature_names = [
f.replace("regulator_", "") for f in self.downstream_model_ligand_design_matrix.columns
]
elif target_type == "receptor":
all_coeffs = self.downstream_model_receptor_coeffs
dm = self.downstream_model_receptor_design_matrix
all_feature_names = [
f.replace("regulator_", "") for f in self.downstream_model_receptor_design_matrix.columns
]
elif target_type == "tf_target":
all_coeffs = self.downstream_model_target_coeffs
dm = self.downstream_model_target_design_matrix
all_feature_names = [f for f in self.downstream_model_target_design_matrix.columns]
elif target_type == "target_gene":
all_coeffs = self.coeffs
dm = self.design_matrix
all_feature_names = [f for f in self.design_matrix.columns]
else:
raise ValueError(
Expand All @@ -7164,7 +7182,7 @@ def deg_effect_heatmap(
os.makedirs(figure_folder)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)
all_proportions = pd.DataFrame()
all_plot_values = pd.DataFrame()

for target in all_coeffs.keys():
all_coeffs_target = all_coeffs[target]
Expand All @@ -7177,34 +7195,56 @@ def deg_effect_heatmap(
all_feature_names = [f for f in all_feature_names if f in all_coeffs_target.columns]
all_coeffs_target = all_coeffs_target[all_feature_names]

nz_cells = np.where(self.adata[:, target].X.toarray() > 0)[0]
if len(nz_cells) > 0:
proportions = (all_coeffs_target.iloc[nz_cells] != 0).mean()
all_proportions[target] = proportions
if to_plot == "proportion":
nz_cells = np.where(self.adata[:, target].X.toarray() > 0)[0]
print(len(nz_cells))
for interaction in all_feature_names:
print(target)
print(interaction)
proportions = (all_coeffs_target.iloc[nz_cells][interaction] != 0).mean()
print((all_coeffs_target.iloc[nz_cells][interaction] != 0).sum())
all_plot_values.loc[interaction, target] = proportions
elif to_plot == "specificity":
for interaction in all_feature_names:
if target_type in ["ligand", "receptor", "tf_target"]:
all_cells_affected = dm[dm[f"regulator_{interaction}"] > 0]
else:
all_cells_affected = dm[dm[interaction] > 0]
specificity = (all_coeffs_target.loc[all_cells_affected.index, interaction] != 0).mean()
all_plot_values.loc[interaction, target] = specificity
all_plot_values.index = [replace_col_with_collagens(f) for f in all_plot_values.index]
all_plot_values.index = [replace_hla_with_hlas(f) for f in all_plot_values.index]

if order_interactions:
interaction_dist_mat = pdist(all_proportions.values, metric="euclidean")
interaction_dist_mat = pdist(all_plot_values.values, metric="euclidean")
interaction_linkage = linkage(interaction_dist_mat, method="ward")
interaction_order = leaves_list(interaction_linkage)
all_proportions = all_proportions.iloc[interaction_order]
all_proportions = all_plot_values.iloc[interaction_order]

if order_targets:
target_dist_mat = pdist(all_proportions.T.values, metric="euclidean")
target_dist_mat = pdist(all_plot_values.T.values, metric="euclidean")
target_linkage = linkage(target_dist_mat, method="ward")
target_order = leaves_list(target_linkage)
all_proportions = all_proportions.T.iloc[target_order].T
all_plot_values = all_plot_values.T.iloc[target_order].T

if remove_all_zero_rows_and_cols:
# Keep rows/columns with any values above the threshold
all_plot_values = all_plot_values.loc[
(all_plot_values > lower_proportion_threshold).any(axis=1),
(all_plot_values > lower_proportion_threshold).any(axis=0),
]

thickness = 0.3 * figsize[0] / 10
mask = np.abs(all_proportions) < lower_proportion_threshold
thickness = 0.5 * figsize[0] / 10
mask = np.abs(all_plot_values) < lower_proportion_threshold
m = sns.heatmap(
all_proportions,
all_plot_values,
square=True,
linecolor="grey",
linewidths=thickness,
cbar_kws={"label": "Proportion", "location": "top", "pad": 0.05},
cbar_kws={"label": to_plot.title(), "location": "top", "pad": 0.05},
cmap=cmap,
vmin=0,
vmax=all_proportions.max().max(),
vmax=all_plot_values.max().max(),
mask=mask,
ax=ax,
)
Expand All @@ -7216,9 +7256,9 @@ def deg_effect_heatmap(

# Adjust colorbar label font size
cbar = m.collections[0].colorbar
cbar.set_label("Proportion", fontsize=fontsize * 2, labelpad=10)
cbar.set_label(to_plot.title(), fontsize=fontsize * 1.5, labelpad=10)
# Adjust colorbar tick font size
cbar.ax.tick_params(labelsize=fontsize * 2.0)
cbar.ax.tick_params(labelsize=fontsize * 1.5)
cbar.ax.set_aspect(0.02)

if target_type == "ligand":
Expand All @@ -7233,16 +7273,26 @@ def deg_effect_heatmap(
ax.set_ylabel(y_label, fontsize=fontsize * 2)
ax.tick_params(axis="x", labelsize=fontsize, rotation=90)
ax.tick_params(axis="y", labelsize=fontsize)
ax.set_title(f"Proportion of target-expressing cells affected by each {id}", fontsize=fontsize * 2.25, pad=20)
title_fontsize = fontsize * 2.25 if figsize[0] > 20 else fontsize * 2
title = (
f"Proportion of target-expressing cells \naffected by each {id}"
if to_plot == "proportion"
else f"Specificity of each {id}"
)
ax.set_title(title, fontsize=title_fontsize, pad=20)
prefix = "heatmap"

if save_df:
all_proportions.to_csv(
all_plot_values.to_csv(
os.path.join(
output_folder,
f"{prefix}_{adata_id}_proportion_affected_by_interaction.csv",
)
)
self.logger.info(
f"Saving {to_plot} dataframe to "
f"{os.path.join(output_folder,f'{prefix}_{adata_id}_proportion_affected_by_interaction.csv')}"
)

if save_show_or_return in ["save", "both", "all"]:
save_kwargs["ext"] = "png"
Expand Down

0 comments on commit 2a8e99f

Please sign in to comment.