Skip to content

Commit

Permalink
update backbone clustering
Browse files Browse the repository at this point in the history
  • Loading branch information
Yao-14 committed Sep 13, 2023
1 parent eb31d08 commit d40edbf
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 1 deletion.
1 change: 1 addition & 0 deletions spateo/tdr/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
ElPiGraph_method,
PrinCurve_method,
SimplePPT_method,
backbone_scc,
construct_backbone,
map_gene_to_backbone,
map_points_to_backbone,
Expand Down
2 changes: 1 addition & 1 deletion spateo/tdr/models/models_backbone/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .backbone import construct_backbone, update_backbone
from .backbone import backbone_scc, construct_backbone, update_backbone
from .backbone_methods import ElPiGraph_method, PrinCurve_method, SimplePPT_method
from .backbone_utils import map_gene_to_backbone, map_points_to_backbone
82 changes: 82 additions & 0 deletions spateo/tdr/models/models_backbone/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
except ImportError:
from typing_extensions import Literal

import anndata as ad
import numpy as np
import pandas as pd
from anndata import AnnData
from pyvista import PolyData, UnstructuredGrid
from scipy.sparse import issparse
from scipy.spatial.distance import cdist


Expand Down Expand Up @@ -149,3 +152,82 @@ def update_backbone(

updated_backbone.point_data[key_added] = np.arange(0, updated_backbone.n_points, 1)
return updated_backbone


def backbone_scc(
adata: AnnData,
backbone: PolyData,
genes: Optional[list] = None,
adata_nodes_key: str = "backbone_nodes",
backbone_nodes_key: str = "updated_nodes",
key_added: Optional[str] = "backbone_scc",
layer: Optional[str] = None,
e_neigh: int = 10,
s_neigh: int = 6,
cluster_method: Literal["leiden", "louvain"] = "leiden",
resolution: Optional[float] = None,
inplace: bool = True,
) -> Optional[AnnData]:
"""
Spatially constrained clustering (scc) along the backbone.
Args:
adata: The anndata object.
backbone: The backbone model.
genes: The list of genes that will be used to subset the data for clustering. If ``genes = None``, all genes will be used.
adata_nodes_key: The key that corresponds to the nodes in the adata.
backbone_nodes_key: The key that corresponds to the nodes in the backbone.
key_added: adata.obs key under which to add the cluster labels.
layer: The layer that will be used to retrieve data for dimension reduction and clustering. If ``layer = None``, ``.X`` is used.
e_neigh: the number of nearest neighbor in gene expression space.
s_neigh: the number of nearest neighbor in physical space.
cluster_method: the method that will be used to cluster the cells.
resolution: the resolution parameter of the louvain clustering algorithm.
inplace: Whether to copy adata or modify it inplace.
Returns:
An ``AnnData`` object is updated/copied with the ``key_added`` in the ``.obs`` attribute, storing the clustering results.
"""
import dynamo as dyn
from dynamo.tools.utils import fetch_X_data

from ....tools import scc

adata = adata if inplace else adata.copy()
if "pp" not in adata.uns.keys():
adata.uns["pp"] = {}
genes, X_data = fetch_X_data(adata, genes, layer)
X_data = X_data.A if issparse(X_data) else X_data
X_data = pd.DataFrame(X_data, columns=genes)
X_data[adata_nodes_key] = adata.obs[adata_nodes_key].values
X_data = pd.DataFrame(X_data.groupby(by=adata_nodes_key).mean())
backbone_nodes = X_data.index

X_spatial = pd.DataFrame(backbone.points, index=backbone.point_data[backbone_nodes_key])
X_spatial = X_spatial.loc[backbone_nodes, :].values

backbone_adata = ad.AnnData(
X=X_data.values,
var=pd.DataFrame(index=X_data.columns),
obs=pd.DataFrame(backbone_nodes, columns=[adata_nodes_key]),
obsm={"spatial": X_spatial},
uns={"__type": "UMI", "pp": {}},
)

dyn.pp.normalize(backbone_adata)
dyn.pp.log1p(backbone_adata)
backbone_adata.obsm["X_backbone"] = backbone_adata.X
scc(
backbone_adata,
spatial_key="spatial",
pca_key="X_backbone",
e_neigh=e_neigh,
s_neigh=s_neigh,
resolution=resolution,
key_added="scc",
cluster_method=cluster_method,
)

cluster_dict = {i: c for i, c in zip(backbone_adata.obs[adata_nodes_key], backbone_adata.obs["scc"])}
adata.obs[key_added] = adata.obs[adata_nodes_key].map(lambda x: cluster_dict[x])
return None if inplace else adata

0 comments on commit d40edbf

Please sign in to comment.