Skip to content

Commit

Permalink
Merge pull request #255 from Sichao25/restore
Browse files Browse the repository at this point in the history
Restore contour.py
  • Loading branch information
Xiaojieqiu authored Jul 25, 2024
2 parents e61b98d + 5baf1b0 commit 35bfda5
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 0 deletions.
1 change: 1 addition & 0 deletions spateo/plotting/static/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .align import multi_slices
from .bbs import delaunay, polygon
from .colorlabel import color_label
from .contour import spatial_domains
from .dotplot import dotplot
from .geo import geo
from .glm import glm_fit, glm_heatmap
Expand Down
68 changes: 68 additions & 0 deletions spateo/plotting/static/contour.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Written by @Jinerhal, adapted by @Xiaojieqiu.
"""
from typing import Dict, Optional, Tuple, Union

import cv2
import numpy as np
from anndata import AnnData

from ...configuration import SKM
from .utils import save_return_show_fig_utils


@SKM.check_adata_is_type(SKM.ADATA_UMI_TYPE, "adata")
def spatial_domains(
adata: AnnData,
bin_size: Optional[int] = None,
spatial_key: str = "spatial",
label_key: str = "cluster_img_label",
plot_size=(3, 3),
save_img="spatial_domains.png",
):
"""Generate an image with contours of each spatial domains.
Args:
adata: The adata object used to create the image.
bin_size: The size of the binning. Default to None.
spatial_key: The key name of the spatial coordinates. Default to "spatial".
label_key: The key name of the image label values. Default to "cluster_img_label".
plot_size: figsize for showing the image.
save_img: path to saving image file.
"""
import matplotlib.pyplot as plt
from numpngw import write_png

label_list = np.unique(adata.obs[label_key])
labels = np.zeros(len(adata))
for i in range(len(label_list)):
labels[adata.obs[label_key] == label_list[i]] = i + 1

if bin_size is None:
bin_size = adata.uns["bin_size"]

label_img = np.zeros(
(
int(max(adata.obsm[spatial_key][:, 0] // bin_size)) + 1,
int(max(adata.obsm[spatial_key][:, 1] // bin_size)) + 1,
)
)
for i in range(len(adata)):
label_img[
int(adata.obsm[spatial_key][i, 0] // bin_size), int(adata.obsm[spatial_key][i, 1] // bin_size)
] = labels[i]

contour_img = label_img.copy()
contour_img[:, :] = 255
for i in np.unique(label_img):
if i == 0:
continue
label_img_gray = np.where(label_img == i, 0, 1).astype("uint8")
_, thresh = cv2.threshold(label_img_gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
contour, _ = cv2.findContours(thresh, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
contour_img = cv2.drawContours(contour_img, contour[:], -1, 0.5, 1)

fig = plt.figure()
fig.set_size_inches(plot_size[0], plot_size[1])
plt.imshow(contour_img, cmap="tab20", origin="lower")

write_png(save_img, contour_img.astype("uint8"))

0 comments on commit 35bfda5

Please sign in to comment.