Skip to content

Commit

Permalink
update three_dims_plots
Browse files Browse the repository at this point in the history
  • Loading branch information
Yao-14 committed Aug 21, 2023
1 parent 41c33be commit b183480
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 14 deletions.
13 changes: 0 additions & 13 deletions spateo/plotting/static/three_d_plot/morphometrics_plots.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import Optional, Union

import matplotlib as mpl
import numpy as np
import pandas as pd
from anndata import AnnData
from matplotlib.colors import LinearSegmentedColormap
from pyvista import MultiBlock, PolyData, UnstructuredGrid

from ....tdr import add_model_labels, collect_models
Expand All @@ -16,15 +14,6 @@
from typing_extensions import Literal


def _get_default_cmap():
if "default_cmap" not in mpl.colormaps():
colors = ["#4B0082", "#800080", "#F97306", "#FFA500", "#FFD700", "#FFFFCB"]
nodes = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

mpl.colormaps.register(LinearSegmentedColormap.from_list("default_cmap", list(zip(nodes, colors))))
return "default_cmap"


def _check_index_in_adata(adata, model):
adata_obs_index = pd.DataFrame(range(len(adata.obs.index)), index=adata.obs.index, columns=["ind"])
obs_index = (
Expand Down Expand Up @@ -204,7 +193,6 @@ def jacobian(

# Visualization.
j_keys = [f"∂{f}/∂{i}" for f in ["fx", "fy", "fz"] for i in ["x", "y", "z"]]
colormap = _get_default_cmap() if colormap is None or colormap == "default_cmap" else colormap
return three_d_multi_plot(
model=collect_models([models]),
key=j_keys,
Expand Down Expand Up @@ -350,7 +338,6 @@ def feature(
)

# Visualization.
colormap = _get_default_cmap() if colormap is None or colormap == "default_cmap" else colormap
return three_d_plot(
model=models,
key=feature_key,
Expand Down
12 changes: 11 additions & 1 deletion spateo/plotting/static/three_d_plot/three_dims_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import matplotlib as mpl
import numpy as np
import pyvista as pv
from matplotlib.colors import LinearSegmentedColormap
from pyvista import MultiBlock, Plotter, PolyData, UnstructuredGrid

try:
Expand All @@ -11,6 +12,15 @@
from typing_extensions import Literal


def _get_default_cmap():
if "default_cmap" not in mpl.colormaps():
colors = ["#4B0082", "#800080", "#F97306", "#FFA500", "#FFD700", "#FFFFCB"]
nodes = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

mpl.colormaps.register(LinearSegmentedColormap.from_list("default_cmap", list(zip(nodes, colors))))
return "default_cmap"


def create_plotter(
jupyter: bool = False,
off_screen: bool = False,
Expand Down Expand Up @@ -40,6 +50,7 @@ def create_plotter(
"""

# Create an initial plotting object.
_get_default_cmap()
plotter = pv.Plotter(
off_screen=off_screen,
window_size=window_size,
Expand All @@ -58,7 +69,6 @@ def create_plotter(
plotter.add_camera_orientation_widget()
else:
plotter.add_axes()

return plotter


Expand Down

0 comments on commit b183480

Please sign in to comment.