Skip to content

Commit

Permalink
Merge pull request #253 from Sichao25/dependency
Browse files Browse the repository at this point in the history
Upgrade matplotlib
  • Loading branch information
Xiaojieqiu authored Sep 18, 2024
2 parents 22220c5 + a38c430 commit 6c6ce35
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 81 deletions.
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ colorcet>=2.0.1
cvxopt>=1.2.3
csbdeep>=0.6.3
descartes
dynamo-release>=1.3.0
dynamo-release>=1.4.1
folium>=0.12.1
geopandas>=0.10.2
gpytorch
Expand All @@ -14,7 +14,7 @@ kornia>=0.6.4
leidenalg>=0.10.0
loompy>=3.0.5
mapclassify>=2.4.2
matplotlib<=3.5.3
matplotlib>=3.7.5
nbconvert
networkx>=2.6.3
# ngs_tools>=1.6.0
Expand Down
84 changes: 39 additions & 45 deletions spateo/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import List, Optional, Tuple, Union

import colorcet
import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -247,12 +247,12 @@ def get_agg_bounds(adata: AnnData) -> Tuple[int, int, int, int]:

# Means to shift the scale of colormaps:
def shiftedColorMap(
cmap: matplotlib.colors.ListedColormap,
cmap: mpl.colors.ListedColormap,
start: float = 0,
midpoint: float = 0.5,
stop: float = 1.0,
name: str = "shiftedcmap",
) -> matplotlib.colors.ListedColormap:
) -> mpl.colors.ListedColormap:
"""
Function to offset the "center" of a colormap. Useful for
data with a negative min and positive max, and you want the
Expand All @@ -279,7 +279,7 @@ def shiftedColorMap(
newcmap: a new colormap that has the middle point of the colormap shifted.
"""
# Check for existing shifted colormap:
matplotlib.cm.ColormapRegistry.unregister(plt.colormaps, name="shiftedcmap")
mpl.cm.ColormapRegistry.unregister(plt.colormaps, name="shiftedcmap")

cdict = {"red": [], "green": [], "blue": [], "alpha": []}

Expand All @@ -299,54 +299,48 @@ def shiftedColorMap(
cdict["blue"].append((si, b, b))
cdict["alpha"].append((si, a, a))

newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)
plt.register_cmap(cmap=newcmap)
newcmap = mpl.colors.LinearSegmentedColormap(name, cdict)
mpl.colormaps.register(cmap=newcmap)

return newcmap


fire_cmap = matplotlib.colors.LinearSegmentedColormap.from_list("fire", colorcet.fire)
darkblue_cmap = matplotlib.colors.LinearSegmentedColormap.from_list("darkblue", colorcet.kbc)
darkgreen_cmap = matplotlib.colors.LinearSegmentedColormap.from_list("darkgreen", colorcet.kgy)
darkred_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
"darkred", colors=colorcet.linear_kry_5_95_c72[:192], N=256
)
darkpurple_cmap = matplotlib.colors.LinearSegmentedColormap.from_list("darkpurple", colorcet.linear_bmw_5_95_c89)
fire_cmap = mpl.colors.LinearSegmentedColormap.from_list("fire", colorcet.fire)
darkblue_cmap = mpl.colors.LinearSegmentedColormap.from_list("darkblue", colorcet.kbc)
darkgreen_cmap = mpl.colors.LinearSegmentedColormap.from_list("darkgreen", colorcet.kgy)
darkred_cmap = mpl.colors.LinearSegmentedColormap.from_list("darkred", colors=colorcet.linear_kry_5_95_c72[:192], N=256)
darkpurple_cmap = mpl.colors.LinearSegmentedColormap.from_list("darkpurple", colorcet.linear_bmw_5_95_c89)
# add gkr theme
div_blue_black_red_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
div_blue_black_red_cmap = mpl.colors.LinearSegmentedColormap.from_list(
"div_blue_black_red", colorcet.diverging_gkr_60_10_c40
)
# add RdBu_r theme
div_blue_red_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
"div_blue_red", colorcet.diverging_bwr_55_98_c37
)
div_blue_red_cmap = mpl.colors.LinearSegmentedColormap.from_list("div_blue_red", colorcet.diverging_bwr_55_98_c37)
# add glasbey_bw for cell annotation in white background
glasbey_white_cmap = matplotlib.colors.LinearSegmentedColormap.from_list("glasbey_white", colorcet.glasbey_bw_minc_20)
glasbey_white_cmap = mpl.colors.LinearSegmentedColormap.from_list("glasbey_white", colorcet.glasbey_bw_minc_20)
# add glasbey_bw_minc_20_maxl_70 theme for cell annotation in dark background
glasbey_dark_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
"glasbey_dark", colorcet.glasbey_bw_minc_20_maxl_70
)
glasbey_dark_cmap = mpl.colors.LinearSegmentedColormap.from_list("glasbey_dark", colorcet.glasbey_bw_minc_20_maxl_70)

with warnings.catch_warnings():
warnings.simplefilter("ignore")
if "fire" not in matplotlib.colormaps():
plt.register_cmap("fire", fire_cmap)
if "darkblue" not in matplotlib.colormaps():
plt.register_cmap("darkblue", darkblue_cmap)
if "darkgreen" not in matplotlib.colormaps():
plt.register_cmap("darkgreen", darkgreen_cmap)
if "darkred" not in matplotlib.colormaps():
plt.register_cmap("darkred", darkred_cmap)
if "darkpurple" not in matplotlib.colormaps():
plt.register_cmap("darkpurple", darkpurple_cmap)
if "div_blue_black_red" not in matplotlib.colormaps():
plt.register_cmap("div_blue_black_red", div_blue_black_red_cmap)
if "div_blue_red" not in matplotlib.colormaps():
plt.register_cmap("div_blue_red", div_blue_red_cmap)
if "glasbey_white" not in matplotlib.colormaps():
plt.register_cmap("glasbey_white", glasbey_white_cmap)
if "glasbey_dark" not in matplotlib.colormaps():
plt.register_cmap("glasbey_dark", glasbey_dark_cmap)
if "fire" not in mpl.colormaps():
mpl.colormaps.register(cmap=fire_cmap, name="fire")
if "darkblue" not in mpl.colormaps():
mpl.colormaps.register(cmap=darkblue_cmap, name="darkblue")
if "darkgreen" not in mpl.colormaps():
mpl.colormaps.register(cmap=darkgreen_cmap, name="darkgreen")
if "darkred" not in mpl.colormaps():
mpl.colormaps.register(cmap=darkred_cmap, name="darkred")
if "darkpurple" not in mpl.colormaps():
mpl.colormaps.register(cmap=darkpurple_cmap, name="darkpurple")
if "div_blue_black_red" not in mpl.colormaps():
mpl.colormaps.register(cmap=div_blue_black_red_cmap, name="div_blue_black_red")
if "div_blue_red" not in mpl.colormaps():
mpl.colormaps.register(cmap=div_blue_red_cmap, name="div_blue_red")
if "glasbey_white" not in mpl.colormaps():
mpl.colormaps.register(cmap=glasbey_white_cmap, name="glasbey_white")
if "glasbey_dark" not in mpl.colormaps():
mpl.colormaps.register(cmap=glasbey_dark_cmap, name="glasbey_dark")

_themes = {
"fire": {
Expand Down Expand Up @@ -512,7 +506,7 @@ def config_spateo_rcParams(
background: str = "white",
prop_cycle: List[str] = zebrafish_256,
fontsize: int = 8,
color_map: matplotlib.colors.ListedColormap = None,
color_map: mpl.colors.ListedColormap = None,
frameon: Optional[bool] = None,
) -> None:
"""Configure matplotlib.rcParams to spateo defaults (based on ggplot style and scanpy).
Expand Down Expand Up @@ -735,8 +729,8 @@ def set_pub_style(scaler: float = 1) -> None:
"""

set_figure_params("spateo", background="white")
matplotlib.use("cairo")
matplotlib.rcParams.update({"font.size": 6 * scaler})
mpl.use("cairo")
mpl.rcParams.update({"font.size": 6 * scaler})
params = {
"font.size": 6 * scaler,
"legend.fontsize": 6 * scaler,
Expand All @@ -748,13 +742,13 @@ def set_pub_style(scaler: float = 1) -> None:
"axes.titlepad": 1 * scaler,
"axes.labelpad": 1 * scaler,
}
matplotlib.rcParams.update(params)
mpl.rcParams.update(params)


def set_pub_style_mpltex():
"""formatting helper function based on mpltex package that can be used to save publishable figures"""
set_figure_params("spateo", background="white")
matplotlib.use("cairo")
mpl.use("cairo")
# the following code is adapted from https://github.com/liuyxpp/mpltex
# latex_preamble = r"\usepackage{siunitx}\sisetup{detect-all}\usepackage{helvet}\usepackage[eulergreek,EULERGREEK]{sansmath}\sansmath"
params = {
Expand Down Expand Up @@ -811,4 +805,4 @@ def set_pub_style_mpltex():
"axes.titlepad": 1,
"axes.labelpad": 1,
}
matplotlib.rcParams.update(params)
mpl.rcParams.update(params)
3 changes: 2 additions & 1 deletion spateo/digitization/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def gen_cluster_image(
prepared from the designated cmap.
"""

import matplotlib as mpl
import matplotlib.pyplot as plt

if bin_size is None:
Expand All @@ -49,7 +50,7 @@ def gen_cluster_image(
lm.main_info(f"Set up the color for the clusters with the {cmap} colormap.")

# TODO: what if cluster number is larger than cmap.N?
cmap = plt.cm.get_cmap(cmap)
cmap = mpl.colormaps[cmap]
colors = cmap(np.arange(cmap.N))
color_ls = []
for i in range(cmap.N):
Expand Down
2 changes: 1 addition & 1 deletion spateo/plotting/static/interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def plot_connections(

# Set label colors:
if isinstance(colormap, str):
cmap = mpl.cm.get_cmap(colormap)
cmap = mpl.colormaps[colormap]
else:
cmap = colormap

Expand Down
3 changes: 2 additions & 1 deletion spateo/plotting/static/scatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def scatters(
then this will simply display inline.
"""

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import rcParams
from matplotlib.colors import rgb2hex, to_hex
Expand Down Expand Up @@ -776,7 +777,7 @@ def _plot_basis_layer(cur_b, cur_l):
if stack_colors:
# main_debug("stack colors: changing cmap")
_cmap = stack_colors_cmaps[ax_index % len(stack_colors_cmaps)]
max_color = matplotlib.cm.get_cmap(_cmap)(float("inf"))
max_color = mpl.colormaps[_cmap](float("inf"))
legend_circle = Line2D(
[0],
[0],
Expand Down
2 changes: 1 addition & 1 deletion spateo/plotting/static/three_d_plot/three_dims_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,7 +1251,7 @@ def visualize_3D_increasing_direction_gradient(
1 - (1 - coords_norm) * (1 - new_center) / 0.5, # Compress the upper half
)

colors = mpl.cm.get_cmap(cmap)(coords_norm)
colors = mpl.colormaps[cmap](coords_norm)
# Convert colors to hex format:
colors = ["#" + "".join([f"{int(c * 255):02x}" for c in color[:3]]) for color in colors]

Expand Down
34 changes: 16 additions & 18 deletions spateo/plotting/static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
warnings.simplefilter("ignore")
import geopandas as gpd

import matplotlib
import matplotlib as mpl
import matplotlib.patheffects as PathEffects
import matplotlib.pyplot as plt
import mpl_toolkits
Expand Down Expand Up @@ -75,23 +75,21 @@ def _get_adata_color_vec(adata, layer, col):


def map2color(val, min=None, max=None, cmap="viridis"):
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt

minima = np.min(val) if min is None else min
maxima = np.max(val) if max is None else max

norm = matplotlib.colors.Normalize(vmin=minima, vmax=maxima, clip=True)
mapper = cm.ScalarMappable(norm=norm, cmap=plt.get_cmap(cmap))
norm = mpl.colors.Normalize(vmin=minima, vmax=maxima, clip=True)
mapper = cm.ScalarMappable(norm=norm, cmap=mpl.colormaps[cmap])

cols = [mapper.to_rgba(v) for v in val]

return cols


def _to_hex(arr):
return [matplotlib.colors.to_hex(c) for c in arr]
return [mpl.colors.to_hex(c) for c in arr]


# https://stackoverflow.com/questions/8468855/convert-a-rgb-colour-value-to-decimal
Expand Down Expand Up @@ -399,7 +397,7 @@ def _matplotlib_points(
)
if color_key is None:
# main_debug("color_key is None")
cmap = copy.copy(matplotlib.cm.get_cmap(color_key_cmap))
cmap = copy.copy(mpl.colormaps[color_key_cmap])
cmap.set_bad("lightgray")
colors = None

Expand Down Expand Up @@ -594,12 +592,12 @@ def _matplotlib_points(
# Color by values
elif values is not None:
# main_debug("drawing points by values")
cmap_ = copy.copy(matplotlib.cm.get_cmap(cmap))
cmap_ = copy.copy(mpl.colormaps[cmap])
cmap_.set_bad("lightgray")

with warnings.catch_warnings():
warnings.simplefilter("ignore")
matplotlib.cm.register_cmap(name=cmap_.name, cmap=cmap_, override_builtin=True)
mpl.colormaps.register(name=cmap_.name, cmap=cmap_, force=True)

if values.shape[0] != points.shape[0]:
raise ValueError(
Expand Down Expand Up @@ -780,9 +778,9 @@ def _matplotlib_points(
if "norm" in kwargs:
norm = kwargs["norm"]
else:
norm = matplotlib.colors.Normalize(vmin=_vmin, vmax=_vmax)
norm = mpl.colors.Normalize(vmin=_vmin, vmax=_vmax)

mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
mappable.set_array(values)
if show_colorbar:
cb = plt.colorbar(mappable, cax=set_colorbar(ax, inset_dict), ax=ax)
Expand All @@ -791,12 +789,12 @@ def _matplotlib_points(
cb.locator = MaxNLocator(nbins=3, integer=True)
cb.update_ticks()

cmap = matplotlib.cm.get_cmap(cmap)
cmap = mpl.colormaps[cmap]
colors = cmap(values)
# No color (just pick the midpoint of the cmap)
else:
# main_debug("drawing points without color passed in args, using midpoint of the cmap")
colors = plt.get_cmap(cmap)(0.5)
colors = mpl.colormaps[cmap](0.5)
if geo:
_geo_projection(ax, points, color=colors, **kwargs)

Expand All @@ -823,7 +821,7 @@ def _matplotlib_points(
ax.legend(
handles=legend_elements,
bbox_to_anchor=(1.04, 1),
loc=matplotlib.rcParams["legend.loc"],
loc=mpl.rcParams["legend.loc"],
ncol=len(unique_labels) // 20 + 1,
prop=dict(size=8),
)
Expand Down Expand Up @@ -1486,8 +1484,8 @@ def save_return_show_fig_utils(
prefix: str,
save_kwargs: Dict,
total_panels: int,
fig: matplotlib.figure.Figure,
axes: matplotlib.axes.Axes,
fig: mpl.figure.Figure,
axes: mpl.axes.Axes,
return_all: bool,
return_all_list: Union[List, Tuple, None],
) -> Optional[Tuple]:
Expand Down Expand Up @@ -1585,7 +1583,7 @@ def check_colornorm(
vmin: Union[None, float] = None,
vmax: Union[None, float] = None,
vcenter: Union[None, float] = None,
norm: Union[None, matplotlib.colors.Normalize] = None,
norm: Union[None, mpl.colors.Normalize] = None,
):
"""
When plotting continuous variables, configure a normalizer object for the purposes of mapping the data to varying
Expand Down Expand Up @@ -1824,7 +1822,7 @@ def dendrogram(


def plot_dendrogram(
dendro_ax: matplotlib.axes.Axes,
dendro_ax: mpl.axes.Axes,
adata: AnnData,
cat_key: str,
dendrogram_key: Union[None, str] = None,
Expand Down
2 changes: 1 addition & 1 deletion spateo/tdr/models/utilities/label_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def add_model_labels(
# Set raw hex.
if isinstance(colormap, str):
if colormap in list(mpl.colormaps()):
lscmap = mpl.cm.get_cmap(colormap)
lscmap = mpl.colormaps[colormap]
raw_hex_list = [mpl.colors.to_hex(lscmap(i)) for i in np.linspace(0, 1, len(cu_arr))]
for label, color in zip(cu_arr, raw_hex_list):
raw_labels_hex[raw_labels_hex == label] = color
Expand Down
Loading

0 comments on commit 6c6ce35

Please sign in to comment.