diff --git a/legendkit/_colorart.py b/legendkit/_colorart.py index 744cb62..c4ec6a1 100644 --- a/legendkit/_colorart.py +++ b/legendkit/_colorart.py @@ -153,11 +153,17 @@ def __init__(self, rasterized=True, ): + super().__init__() if ax is None: ax = plt.gca() - self.ax = ax + self.is_axes = True + if not isinstance(ax, Axes): + self.is_axes = False + self.set_figure(ax) - super().__init__() + else: + self.set_figure(ax.figure) + self.axes = ax if rasterized: # Force rasterization self._rasterized = True @@ -265,6 +271,7 @@ def __init__(self, Locs().transform(ax, loc, bbox_to_anchor=bbox_to_anchor, bbox_transform=bbox_transform, deviation=deviation) + print(self._bbox_to_anchor) self.textpad = mpl.rcParams[ 'legend.handletextpad'] if textpad is None else textpad @@ -314,6 +321,7 @@ def _make_cbar_box(self): # Add cbar canvas = DrawingArea(da_width, da_height, clip=False) # self._add_color_patches(self._cbar_canvas) + canvas.set_figure(self.figure) cmap_caller = get_colormap(self.cmap) colors_list = cmap_caller(np.arange(cmap_caller.N)) @@ -397,7 +405,7 @@ def _make_cbar_box(self): children=[title_canvas, canvas], align=self.alignment ) - title_pack.axes = self.ax + title_pack.set_figure(self.figure) final_pack = title_pack else: final_pack = canvas @@ -408,17 +416,23 @@ def _make_cbar_box(self): bbox_transform=self._bbox_transform, bbox_to_anchor=self._bbox_to_anchor, frameon=False) - self.ax.add_artist(self._cbar_box) + self._cbar_box.set_figure(self.figure) + if self.is_axes: + self.axes.add_artist(self._cbar_box) + else: + self.figure.add_artist(self._cbar_box) def _get_text_size(self, ticklabels): """Used to get the proper size for drawing area""" - fig = self.ax.get_figure() - renderer = fig.canvas.get_renderer() + renderer = self.figure.canvas.get_renderer() all_texts = [] for t in ticklabels: text_obj = Text(0, 0, t, fontsize=self._fontsize, fontproperties=self.prop) - self.ax.add_artist(text_obj) + if self.is_axes: + self.axes.add_artist(text_obj) + else: + self.figure.add_artist(text_obj) all_texts.append(text_obj) x_sizes, y_sizes = [], [] for t in all_texts: diff --git a/legendkit/_legend.py b/legendkit/_legend.py index 22cf996..a2f7dcd 100644 --- a/legendkit/_legend.py +++ b/legendkit/_legend.py @@ -4,14 +4,16 @@ import matplotlib.pyplot as plt import numpy as np from matplotlib import _api -from matplotlib.collections import Collection -from matplotlib.colors import is_color_like +from matplotlib.axes import Axes +from matplotlib.collections import Collection, PatchCollection +from matplotlib.colors import is_color_like, Normalize +from matplotlib.figure import FigureBase from matplotlib.font_manager import FontProperties from matplotlib.legend import Legend from matplotlib.lines import Line2D from matplotlib.markers import MarkerStyle from matplotlib.offsetbox import VPacker, HPacker -from matplotlib.patches import Patch +from matplotlib.patches import Patch, Rectangle from ._handlers import CircleHandler, RectHandler, BoxplotHanlder from ._locs import Locs @@ -112,7 +114,7 @@ class ListLegend(Legend): Parameters ---------- - ax : :class:`Axes ` + ax : :class:`Axes ` or :class:`Figure ` The axes to draw the legend legend_items : array-like of (handle, label, styles) See examples @@ -183,7 +185,7 @@ def __repr__(self): return "" def __init__(self, - ax=None, + ax: Axes | FigureBase = None, legend_items=None, handles=None, labels=None, @@ -206,9 +208,16 @@ def __init__(self, _api.check_in_list(["top", "bottom", "left", "right"], title_loc=title_loc) - self._has_axes = ax is not None + self._has_parent = ax is not None + self._is_axes = isinstance(ax, Axes) if ax is None: - ax = plt.gca() + axes = plt.gca() + else: + if not self._is_axes: + fig = ax + axes = fig.get_axes() + else: + axes = [ax] self._title_loc = title_loc self.titlepad = titlepad self._is_patch = False @@ -237,9 +246,13 @@ def val_or_rc(val, rc_name): legend_labels = [] if (legend_items is None) & (handles is None) & (labels is None): - # If only axes is provided, we will try to get - legend_handles, legend_labels = \ - ax.get_legend_handles_labels(handler_map) + legend_handles = [] + legend_labels = [] + for handle in _get_legend_handles(axes, handler_map): + label = handle.get_label() + if label and not label.startswith('_'): + legend_handles.append(handle) + legend_labels.append(label) elif legend_items is not None: for item in legend_items: if len(item) == 2: @@ -262,7 +275,10 @@ def val_or_rc(val, rc_name): legend_handles, legend_labels = handles, labels if loc is None: - loc = "best" + if self._is_axes: + loc = "best" + else: + loc = "center right" else: loc, bbox_to_anchor, bbox_transform = \ Locs().transform(ax, loc, bbox_to_anchor=bbox_to_anchor, @@ -301,10 +317,13 @@ def val_or_rc(val, rc_name): # Attach as legend element # 1. ax.get_legend() will work # 2. legend won't be clipped - if ax.legend_ is None: - ax.legend_ = self + if isinstance(ax, Axes): + if ax.legend_ is None: + ax.legend_ = self + else: + ax.add_artist(self) else: - ax.add_artist(self) + ax.legends.append(self) def _parse_handler(self, handle, handle_size, config=None): if not isinstance(handle, str): @@ -460,7 +479,7 @@ def _get_default_handle_option(handle, fill, color): else: if fill: return {'fc': color, 'ec': color} - return {'fc': 'none', 'ec': color,} + return {'fc': 'none', 'ec': color, } # Modified from mpl.collections.PathCollection.legend_elements diff --git a/legendkit/_locs.py b/legendkit/_locs.py index cd2a374..7ed16a8 100644 --- a/legendkit/_locs.py +++ b/legendkit/_locs.py @@ -1,3 +1,6 @@ +from matplotlib.axes import Axes + + def add_x(x, y, offset): return x + offset, y @@ -14,6 +17,10 @@ def minus_y(x, y, offset): return x, y - offset +def blank(x, y, offset): + return x, y + + class Locs: combs = { 'out upper left': ('lower left', (0, 1), add_y), @@ -31,6 +38,11 @@ class Locs: 'out right upper': ('upper left', (1, 1), add_x), 'out right center': ('center left', (1, 0.5), add_x), 'out right lower': ('lower left', (1, 0), add_x), + + 'lower left': ('lower left', (0, 0), blank), + + 'center left': ('center left', (0, 0.5), blank), + 'center right': ('center right', (1, 0.5), blank), } LOC_OPTIONS = [ @@ -60,7 +72,10 @@ def transform(self, loc = replacement[0] bbox = replacement[1] offset_func = replacement[2] - bbox = offset_func(*bbox, deviation) - return loc, bbox, ax.transAxes - else: - return loc, bbox_to_anchor, bbox_transform + bbox_to_anchor = offset_func(*bbox, deviation) + if isinstance(ax, Axes): + bbox_transform = ax.transAxes + else: + fig = ax.get_figure() + bbox_transform = fig.transSubfigure + return loc, bbox_to_anchor, bbox_transform