Skip to content

Commit

Permalink
Legend can be placed at figure level placement #22
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Milk committed Mar 30, 2024
1 parent 7dca1a1 commit 6580e4e
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 26 deletions.
28 changes: 21 additions & 7 deletions legendkit/_colorart.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,17 @@ def __init__(self,
rasterized=True,

):
super().__init__()
if ax is None:
ax = plt.gca()
self.ax = ax
self.is_axes = True
if not isinstance(ax, Axes):
self.is_axes = False
self.set_figure(ax)

super().__init__()
else:
self.set_figure(ax.figure)
self.axes = ax
if rasterized:
# Force rasterization
self._rasterized = True
Expand Down Expand Up @@ -265,6 +271,7 @@ def __init__(self,
Locs().transform(ax, loc, bbox_to_anchor=bbox_to_anchor,
bbox_transform=bbox_transform,
deviation=deviation)
print(self._bbox_to_anchor)

self.textpad = mpl.rcParams[
'legend.handletextpad'] if textpad is None else textpad
Expand Down Expand Up @@ -314,6 +321,7 @@ def _make_cbar_box(self):
# Add cbar
canvas = DrawingArea(da_width, da_height, clip=False)
# self._add_color_patches(self._cbar_canvas)
canvas.set_figure(self.figure)

cmap_caller = get_colormap(self.cmap)
colors_list = cmap_caller(np.arange(cmap_caller.N))
Expand Down Expand Up @@ -397,7 +405,7 @@ def _make_cbar_box(self):
children=[title_canvas, canvas],
align=self.alignment
)
title_pack.axes = self.ax
title_pack.set_figure(self.figure)
final_pack = title_pack
else:
final_pack = canvas
Expand All @@ -408,17 +416,23 @@ def _make_cbar_box(self):
bbox_transform=self._bbox_transform,
bbox_to_anchor=self._bbox_to_anchor,
frameon=False)
self.ax.add_artist(self._cbar_box)
self._cbar_box.set_figure(self.figure)
if self.is_axes:
self.axes.add_artist(self._cbar_box)
else:
self.figure.add_artist(self._cbar_box)

def _get_text_size(self, ticklabels):
"""Used to get the proper size for drawing area"""
fig = self.ax.get_figure()
renderer = fig.canvas.get_renderer()
renderer = self.figure.canvas.get_renderer()
all_texts = []
for t in ticklabels:
text_obj = Text(0, 0, t, fontsize=self._fontsize,
fontproperties=self.prop)
self.ax.add_artist(text_obj)
if self.is_axes:
self.axes.add_artist(text_obj)
else:
self.figure.add_artist(text_obj)
all_texts.append(text_obj)
x_sizes, y_sizes = [], []
for t in all_texts:
Expand Down
49 changes: 34 additions & 15 deletions legendkit/_legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import _api
from matplotlib.collections import Collection
from matplotlib.colors import is_color_like
from matplotlib.axes import Axes
from matplotlib.collections import Collection, PatchCollection
from matplotlib.colors import is_color_like, Normalize
from matplotlib.figure import FigureBase
from matplotlib.font_manager import FontProperties
from matplotlib.legend import Legend
from matplotlib.lines import Line2D
from matplotlib.markers import MarkerStyle
from matplotlib.offsetbox import VPacker, HPacker
from matplotlib.patches import Patch
from matplotlib.patches import Patch, Rectangle

from ._handlers import CircleHandler, RectHandler, BoxplotHanlder
from ._locs import Locs
Expand Down Expand Up @@ -112,7 +114,7 @@ class ListLegend(Legend):
Parameters
----------
ax : :class:`Axes <matplotlib.axes.Axes>`
ax : :class:`Axes <matplotlib.axes.Axes>` or :class:`Figure <matplotlib.figure.FigureBase>`
The axes to draw the legend
legend_items : array-like of (handle, label, styles)
See examples
Expand Down Expand Up @@ -183,7 +185,7 @@ def __repr__(self):
return "<ListLegend>"

def __init__(self,
ax=None,
ax: Axes | FigureBase = None,
legend_items=None,
handles=None,
labels=None,
Expand All @@ -206,9 +208,16 @@ def __init__(self,

_api.check_in_list(["top", "bottom", "left", "right"],
title_loc=title_loc)
self._has_axes = ax is not None
self._has_parent = ax is not None
self._is_axes = isinstance(ax, Axes)
if ax is None:
ax = plt.gca()
axes = plt.gca()
else:
if not self._is_axes:
fig = ax
axes = fig.get_axes()
else:
axes = [ax]
self._title_loc = title_loc
self.titlepad = titlepad
self._is_patch = False
Expand Down Expand Up @@ -237,9 +246,13 @@ def val_or_rc(val, rc_name):
legend_labels = []

if (legend_items is None) & (handles is None) & (labels is None):
# If only axes is provided, we will try to get
legend_handles, legend_labels = \
ax.get_legend_handles_labels(handler_map)
legend_handles = []
legend_labels = []
for handle in _get_legend_handles(axes, handler_map):
label = handle.get_label()
if label and not label.startswith('_'):
legend_handles.append(handle)
legend_labels.append(label)
elif legend_items is not None:
for item in legend_items:
if len(item) == 2:
Expand All @@ -262,7 +275,10 @@ def val_or_rc(val, rc_name):
legend_handles, legend_labels = handles, labels

if loc is None:
loc = "best"
if self._is_axes:
loc = "best"
else:
loc = "center right"
else:
loc, bbox_to_anchor, bbox_transform = \
Locs().transform(ax, loc, bbox_to_anchor=bbox_to_anchor,
Expand Down Expand Up @@ -301,10 +317,13 @@ def val_or_rc(val, rc_name):
# Attach as legend element
# 1. ax.get_legend() will work
# 2. legend won't be clipped
if ax.legend_ is None:
ax.legend_ = self
if isinstance(ax, Axes):
if ax.legend_ is None:
ax.legend_ = self
else:
ax.add_artist(self)
else:
ax.add_artist(self)
ax.legends.append(self)

def _parse_handler(self, handle, handle_size, config=None):
if not isinstance(handle, str):
Expand Down Expand Up @@ -460,7 +479,7 @@ def _get_default_handle_option(handle, fill, color):
else:
if fill:
return {'fc': color, 'ec': color}
return {'fc': 'none', 'ec': color,}
return {'fc': 'none', 'ec': color, }


# Modified from mpl.collections.PathCollection.legend_elements
Expand Down
23 changes: 19 additions & 4 deletions legendkit/_locs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from matplotlib.axes import Axes


def add_x(x, y, offset):
return x + offset, y

Expand All @@ -14,6 +17,10 @@ def minus_y(x, y, offset):
return x, y - offset


def blank(x, y, offset):
return x, y


class Locs:
combs = {
'out upper left': ('lower left', (0, 1), add_y),
Expand All @@ -31,6 +38,11 @@ class Locs:
'out right upper': ('upper left', (1, 1), add_x),
'out right center': ('center left', (1, 0.5), add_x),
'out right lower': ('lower left', (1, 0), add_x),

'lower left': ('lower left', (0, 0), blank),

'center left': ('center left', (0, 0.5), blank),
'center right': ('center right', (1, 0.5), blank),
}

LOC_OPTIONS = [
Expand Down Expand Up @@ -60,7 +72,10 @@ def transform(self,
loc = replacement[0]
bbox = replacement[1]
offset_func = replacement[2]
bbox = offset_func(*bbox, deviation)
return loc, bbox, ax.transAxes
else:
return loc, bbox_to_anchor, bbox_transform
bbox_to_anchor = offset_func(*bbox, deviation)
if isinstance(ax, Axes):
bbox_transform = ax.transAxes
else:
fig = ax.get_figure()
bbox_transform = fig.transSubfigure
return loc, bbox_to_anchor, bbox_transform

0 comments on commit 6580e4e

Please sign in to comment.