Skip to content

Commit

Permalink
update backbone
Browse files Browse the repository at this point in the history
  • Loading branch information
Yao-14 committed Sep 11, 2023
1 parent 079e0ff commit eb31d08
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 7 deletions.
2 changes: 1 addition & 1 deletion spateo/plotting/static/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .contour import spatial_domains
from .dotplot import dotplot
from .geo import geo
from .glm import glm_fit
from .glm import glm_fit, glm_heatmap
from .interactions import ligrec, plot_connections
from .lisa import lisa, lisa_quantiles
from .polarity import *
Expand Down
95 changes: 94 additions & 1 deletion spateo/plotting/static/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from anndata import AnnData

Expand Down Expand Up @@ -38,7 +39,7 @@ def glm_fit(
**kwargs,
):
"""
Visualize the glm_degs result.
Plot the glm_degs result in a scatterplot.
Args:
adata: An Anndata object contain glm_degs result in ``.uns[glm_key]``.
Expand Down Expand Up @@ -137,3 +138,95 @@ def glm_fit(
return_all=False,
return_all_list=None,
)


def glm_heatmap(
adata: AnnData,
genes: Optional[Union[str, list]] = None,
feature_x: str = None,
feature_y: str = "expression",
glm_key: str = "glm_degs",
lowess_smooth: bool = True,
frac: float = 0.2,
robust: bool = True,
colormap: str = "vlag",
figsize: tuple = (6, 6),
background_color: str = "white",
show_legend: bool = True,
save_show_or_return: Literal["save", "show", "return", "both", "all"] = "show",
save_kwargs: Optional[dict] = None,
**kwargs,
):
"""
Plot the glm_degs result in a heatmap.
Args:
adata: An Anndata object contain glm_degs result in ``.uns[glm_key]``.
genes: A gene name or a list of genes that will be used to plot.
feature_x: The key in ``.uns[glm_key]['correlation'][gene]`` that corresponds to the independent variables, such as ``'torsion'``, etc.
feature_y: The key in ``.uns[glm_key]['correlation'][gene]`` that corresponds to the dependent variables, such as ``'expression'``, etc.
glm_key: The key in ``.uns`` that corresponds to the glm_degs result.
lowess_smooth: If True, use statsmodels to estimate a nonparametric lowess model (locally weighted linear regression).
frac: Between 0 and 1. The fraction of the data used when estimating each y-value.
robust: If True and vmin or vmax are absent, the colormap range is computed with robust quantiles instead of the extreme values.
colormap: The name of a matplotlib colormap.
figsize: The width and height of figure.
background_color: The background color of the figure.
show_legend: Whether to show the legend.
save_show_or_return: If ``'both'``, it will save and plot the figure at the same time.
If ``'all'``, the figure will be saved, displayed and the associated axis and other object will be return.
save_kwargs: A dictionary that will be passed to the save_fig function.
By default, it is an empty dictionary and the save_fig function will use the ``{"path": None, "prefix": 'scatter',
"dpi": None, "ext": 'pdf', "transparent": True, "close": True, "verbose": True}`` as its parameters.
Otherwise, you can provide a dictionary that properly modify those keys according to your needs.
**kwargs: Additional parameters that will be passed into the ``seaborn.heatmap`` function.
"""
assert not (feature_x is None), "``feature_x`` cannot be None."
assert not (feature_y is None), "``feature_y`` cannot be None."
assert (
glm_key in adata.uns
), f"``glm_key`` does not exist in adata.uns, please replace ``glm_key`` or run st.tl.glm_degs(key_added={glm_key})."

genes = list(adata.uns[glm_key]["glm_result"].index) if genes is None else genes
genes = list(genes) if isinstance(genes, list) else [genes]

genes_data = []
for g in genes:
gene_data = adata.uns[glm_key]["correlation"][g].copy()
gene_data.sort_values(by=feature_x, ascending=True, axis=0, inplace=True)
gene_data = gene_data.loc[:, [feature_x, feature_y]]
data = pd.DataFrame(gene_data.groupby(by=feature_x)[feature_y].mean())
if lowess_smooth:
import statsmodels.api as sm

data = pd.DataFrame(sm.nonparametric.lowess(exog=data.index, endog=data[feature_y], frac=frac))[1]
genes_data.append(data)
genes_data = pd.concat(genes_data, axis=1)
genes_data.fillna(value=0, inplace=True)
genes_data.columns = genes
genes_data = genes_data.T

max_sort = np.argsort(np.argmax(genes_data.values, axis=1))
genes_data = genes_data.iloc[max_sort]

fig, ax = plt.subplots(figsize=figsize)
sns.heatmap(genes_data, cmap=colormap, robust=robust, ax=ax, **kwargs)
ax.set_xlabel(feature_x)
ax.set_ylabel(feature_y)

plt.tight_layout(pad=1)
return save_return_show_fig_utils(
save_show_or_return=save_show_or_return,
show_legend=show_legend,
background=background_color,
prefix="glm_degs",
save_kwargs=save_kwargs,
total_panels=len(genes),
fig=fig,
axes=ax,
return_all=False,
return_all_list=None,
)
1 change: 1 addition & 0 deletions spateo/tdr/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
construct_backbone,
map_gene_to_backbone,
map_points_to_backbone,
update_backbone,
)
from .models_individual import (
construct_cells,
Expand Down
2 changes: 1 addition & 1 deletion spateo/tdr/models/models_backbone/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .backbone import construct_backbone
from .backbone import construct_backbone, update_backbone
from .backbone_methods import ElPiGraph_method, PrinCurve_method, SimplePPT_method
from .backbone_utils import map_gene_to_backbone, map_points_to_backbone
75 changes: 75 additions & 0 deletions spateo/tdr/models/models_backbone/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing_extensions import Literal

import numpy as np
from anndata import AnnData
from pyvista import PolyData, UnstructuredGrid
from scipy.spatial.distance import cdist

Expand Down Expand Up @@ -74,3 +75,77 @@ def construct_backbone(
backbone_length = cdist(XA=np.asarray(s_points), XB=np.asarray(e_points), metric="euclidean").diagonal().sum()

return backbone_model, backbone_length, plot_cmap


def update_backbone(
backbone: PolyData,
nodes_key: str = "nodes",
key_added: str = "updated_nodes",
select_nodes: Optional[Union[list, np.ndarray]] = None,
interactive: Optional[bool] = True,
model_size: Union[float, list] = 8.0,
colormap: str = "Spectral",
) -> Union[PolyData, UnstructuredGrid]:
"""
Update the bakcbone through interaction or input of selected nodes.
Args:
backbone: The backbone model.
nodes_key: The key that corresponds to the coordinates of the nodes in the backbone.
key_added: The key under which to add the labels of new nodes.
select_nodes: Nodes that need to be retained.
interactive: Whether to delete useless nodes interactively. When ``interactive`` is True, ``select_nodes`` is invalid.
model_size: Thickness of backbone. When ``interactive`` is False, ``model_size`` is invalid.
colormap: Colormap of backbone. When ``interactive`` is False, ``colormap`` is invalid.
Returns:
updated_backbone: The updated backbone model.
"""
model = backbone.copy()
if interactive is True:
from ...widgets.clip import _interactive_rectangle_clip
from ...widgets.utils import _interactive_plotter

p = _interactive_plotter()
p.add_point_labels(
model,
labels=nodes_key,
font_size=18,
font_family="arial",
text_color="white",
shape_color="black",
always_visible=True,
)

picked_models, picking_r_list = [], []
if f"{nodes_key}_rgba" in model.array_names:
p.add_mesh(
model,
scalars=f"{nodes_key}_rgba",
rgba=True,
style="wireframe",
render_lines_as_tubes=True,
line_width=model_size,
)
else:
p.add_mesh(
model,
scalars=nodes_key,
style="wireframe",
render_lines_as_tubes=True,
line_width=model_size,
cmap=colormap,
)
_interactive_rectangle_clip(
plotter=p,
model=model,
picking_list=picked_models,
picking_r_list=picking_r_list,
)
p.show(cpos="iso")
updated_backbone = picking_r_list[0]
else:
updated_backbone = model.extract_cells(np.isin(np.asarray(model.point_data[nodes_key]), select_nodes))

updated_backbone.point_data[key_added] = np.arange(0, updated_backbone.n_points, 1)
return updated_backbone
28 changes: 24 additions & 4 deletions spateo/tdr/models/models_backbone/backbone_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from keras.layers import Dense, Input
from keras.models import Model

from .backbone_utils import sort_nodes_of_curve

#####################################################################
# Principal curves algorithm #
# ================================================================= #
Expand Down Expand Up @@ -186,10 +188,21 @@ def ElPiGraph_method(
else:
raise ValueError("`topology` value is wrong." "\nAvailable `topology` are: `'tree'`, `'circle'`, `'curve'`.")

nodes = elpi_tree[0]["NodePositions"] # ['AllNodePositions'][k]
matrix_edges_weights = elpi_tree[0]["ElasticMatrix"] # ['AllElasticMatrices'][k]
matrix_edges_weights = np.triu(matrix_edges_weights, 1)
edges = np.array(np.nonzero(matrix_edges_weights), dtype=int).transpose()
nodes = elpi_tree[0]["NodePositions"]
edges = np.asarray(elpi_tree[0]["Edges"][0])

if str(topology).lower() in ["curve", "circle"]:
unique_values, occurrence_count = np.unique(edges.flatten(), return_counts=True)
started_node_indices = [v for c, v in zip(occurrence_count, unique_values) if c == 1]
started_node = nodes[started_node_indices[0]] if len(started_node_indices) != 0 else nodes[0]

nodes = sort_nodes_of_curve(nodes, started_node)
if str(topology).lower() == "curve":
edges = np.c_[np.arange(0, len(nodes) - 1, 1).reshape(-1, 1), np.arange(1, len(nodes), 1).reshape(-1, 1)]
else:
edges = np.c_[
np.arange(0, len(nodes), 1).reshape(-1, 1), np.asarray(list(range(1, len(nodes))) + [0]).reshape(-1, 1)
]

return nodes, edges

Expand Down Expand Up @@ -320,5 +333,12 @@ def PrinCurve_method(
n_nodes = nodes.shape[0]
edges = np.asarray([np.arange(0, n_nodes, 1), np.arange(1, n_nodes + 1, 1)]).T
edges[-1, 1] = n_nodes - 1
"""
unique_values, occurrence_count = np.unique(edges.flatten(), return_counts=True)
started_node_indices = [v for c, v in zip(occurrence_count, unique_values) if c == 1]
started_node = nodes[started_node_indices[0]] if len(started_node_indices) != 0 else nodes[0]
sorted_nodes = sort_nodes_of_curve(nodes, started_node)
sorted_edges = np.c_[np.arange(0, len(nodes) - 1, 1).reshape(-1, 1), np.arange(1, len(nodes), 1).reshape(-1, 1)]
"""
return nodes, edges
20 changes: 20 additions & 0 deletions spateo/tdr/models/models_backbone/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,23 @@ def map_gene_to_backbone(
tree.point_data[sub_key] = tree_data[sub_key].values

return tree if not inplace else None


def _euclidean_distance(N1, N2):
temp = np.asarray(N1) - np.asarray(N2)
euclid_dist = np.sqrt(np.dot(temp.T, temp))
return euclid_dist


def sort_nodes_of_curve(nodes, started_node):
current_node = tuple(started_node)
remaining_nodes = [tuple(node) for node in nodes]

sorted_nodes = []
while remaining_nodes:
closest_node = min(remaining_nodes, key=lambda x: _euclidean_distance(current_node, x))
sorted_nodes.append(closest_node)
remaining_nodes.remove(closest_node)
current_node = closest_node
sorted_nodes = np.asarray([list(sn) for sn in sorted_nodes])
return sorted_nodes

0 comments on commit eb31d08

Please sign in to comment.