diff --git a/spateo/plotting/static/three_d_plot/morphometrics_plots.py b/spateo/plotting/static/three_d_plot/morphometrics_plots.py index a2724b73..9c6046dd 100644 --- a/spateo/plotting/static/three_d_plot/morphometrics_plots.py +++ b/spateo/plotting/static/three_d_plot/morphometrics_plots.py @@ -1,10 +1,8 @@ from typing import Optional, Union -import matplotlib as mpl import numpy as np import pandas as pd from anndata import AnnData -from matplotlib.colors import LinearSegmentedColormap from pyvista import MultiBlock, PolyData, UnstructuredGrid from ....tdr import add_model_labels, collect_models @@ -16,15 +14,6 @@ from typing_extensions import Literal -def _get_default_cmap(): - if "default_cmap" not in mpl.colormaps(): - colors = ["#4B0082", "#800080", "#F97306", "#FFA500", "#FFD700", "#FFFFCB"] - nodes = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] - - mpl.colormaps.register(LinearSegmentedColormap.from_list("default_cmap", list(zip(nodes, colors)))) - return "default_cmap" - - def _check_index_in_adata(adata, model): adata_obs_index = pd.DataFrame(range(len(adata.obs.index)), index=adata.obs.index, columns=["ind"]) obs_index = ( @@ -204,7 +193,6 @@ def jacobian( # Visualization. j_keys = [f"∂{f}/∂{i}" for f in ["fx", "fy", "fz"] for i in ["x", "y", "z"]] - colormap = _get_default_cmap() if colormap is None or colormap == "default_cmap" else colormap return three_d_multi_plot( model=collect_models([models]), key=j_keys, @@ -350,7 +338,6 @@ def feature( ) # Visualization. - colormap = _get_default_cmap() if colormap is None or colormap == "default_cmap" else colormap return three_d_plot( model=models, key=feature_key, diff --git a/spateo/plotting/static/three_d_plot/three_dims_plotter.py b/spateo/plotting/static/three_d_plot/three_dims_plotter.py index 8652a4bb..691b0525 100644 --- a/spateo/plotting/static/three_d_plot/three_dims_plotter.py +++ b/spateo/plotting/static/three_d_plot/three_dims_plotter.py @@ -3,6 +3,7 @@ import matplotlib as mpl import numpy as np import pyvista as pv +from matplotlib.colors import LinearSegmentedColormap from pyvista import MultiBlock, Plotter, PolyData, UnstructuredGrid try: @@ -11,6 +12,15 @@ from typing_extensions import Literal +def _get_default_cmap(): + if "default_cmap" not in mpl.colormaps(): + colors = ["#4B0082", "#800080", "#F97306", "#FFA500", "#FFD700", "#FFFFCB"] + nodes = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] + + mpl.colormaps.register(LinearSegmentedColormap.from_list("default_cmap", list(zip(nodes, colors)))) + return "default_cmap" + + def create_plotter( jupyter: bool = False, off_screen: bool = False, @@ -40,6 +50,7 @@ def create_plotter( """ # Create an initial plotting object. + _get_default_cmap() plotter = pv.Plotter( off_screen=off_screen, window_size=window_size, @@ -58,7 +69,6 @@ def create_plotter( plotter.add_camera_orientation_widget() else: plotter.add_axes() - return plotter