-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathyaptool.py
628 lines (490 loc) · 17.8 KB
/
yaptool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
"""A collection of handy functions to avoid boilerplate code while using matplotlib."""
__version__ = "0.1"
from typing import List, Literal, Optional, Tuple, Union
import matplotlib # type: ignore
import matplotlib.figure # type: ignore
import matplotlib.pyplot as plt # type: ignore
from matplotlib import rc # type: ignore
from matplotlib.patches import Rectangle # type: ignore
####################
# Internal helpers #
####################
# Type of the ``which`` argument of despine()/respine(): a tuple of any
# length whose items are spine names. A multi-value Literal is the
# equivalent, shorter spelling of a Union of single-value Literals.
SPINES = Tuple[Literal["top", "bottom", "left", "right"], ...]
def _set_fgbg(fg_col: str, bg_col: str):
    """Internal helper that applies a foreground and a background colour.

    Each rcParams key below is mapped to either the foreground ("fg") or
    the background ("bg") colour, and all of them are written to
    ``plt.rcParams`` in a single update (in the order listed).

    Args:
        fg_col:
            Colour for lines, patch edges, text, axes edges and labels,
            tick marks and grid, in matplotlib's colour syntax.
        bg_col:
            Colour for axes faces and for figure / saved-figure faces and
            edges, in matplotlib's colour syntax.
    """
    key_roles = (
        ("lines.color", "fg"),
        ("patch.edgecolor", "fg"),
        ("text.color", "fg"),
        ("axes.facecolor", "bg"),
        ("axes.edgecolor", "fg"),
        ("axes.labelcolor", "fg"),
        ("xtick.color", "fg"),
        ("ytick.color", "fg"),
        ("grid.color", "fg"),
        ("figure.facecolor", "bg"),
        ("figure.edgecolor", "bg"),
        ("savefig.facecolor", "bg"),
        ("savefig.edgecolor", "bg"),
    )
    palette = {"fg": fg_col, "bg": bg_col}
    plt.rcParams.update({key: palette[role] for key, role in key_roles})
######################
# General Aesthetics #
######################
def darkmode(foreground: str = "0.85", background: str = "0.15") -> None:
    """Switch the global matplotlib colours to a dark scheme.

    With the defaults this draws light-grey ink on a dark-grey canvas;
    either colour can be overridden.

    Args:
        foreground:
            Optional foreground colour in matplotlib's colour syntax.
            Defaults to "0.85" (light grey).
        background:
            Optional background colour in matplotlib's colour syntax.
            Defaults to "0.15" (dark grey).

    Returns:
        None
    """
    _set_fgbg(foreground, background)
def lightmode(foreground: str = "0", background: str = "1.0") -> None:
    """Switch the global matplotlib colours to a light scheme.

    With the defaults this is black on white; either colour can be
    overridden.

    Args:
        foreground:
            Optional foreground colour in matplotlib's colour syntax.
            Defaults to "0" (black).
        background:
            Optional background colour in matplotlib's colour syntax.
            Defaults to "1.0" (white).

    Returns:
        None
    """
    _set_fgbg(foreground, background)
def texon() -> None:
    """Turn on TeX rendering of all text.

    Also sets the LaTeX preamble so that ``amsmath`` is loaded; this
    overwrites any preamble configured earlier.

    Returns:
        None
    """
    rc('text', usetex=True)
    plt.rcParams.update({'text.latex.preamble': r'\usepackage{amsmath}'})
def texoff() -> None:
    """Turn TeX rendering of text off again (counterpart of ``texon``).

    Only the ``text.usetex`` setting is changed; any LaTeX preamble that
    was configured is left in place.

    Returns:
        None
    """
    rc("text", usetex=False)
####################
# Types of layouts #
####################
def singleplot(
        size: Tuple[float, float] = (10, 7)
) -> Tuple[matplotlib.figure.Figure, plt.Axes]:
    """Create a new figure holding exactly one plot.

    Args:
        size:
            Optional (width, height) of the figure in inches.
            Defaults to 10x7 inches.

    Returns:
        fig:
            The newly created matplotlib.figure.Figure.
        ax:
            The single pyplot.Axes of that figure.
    """
    return plt.subplots(nrows=1, ncols=1, figsize=size)
def multiplot(
    nrows: int,
    ncols: int,
    size_xy: Tuple[float, float],
    wspace: Optional[float] = None,
    hspace: Optional[float] = None
) -> Tuple[matplotlib.figure.Figure, plt.Axes]:
    """Generates a new figure with a grid of nrows x ncols plots.

    The overall figure size is size_xy. Horizontal and vertical distance
    between plots may be defined explicitly.

    Args:
        nrows:
            Number of rows of plots.
        ncols:
            Number of columns of plots.
        size_xy:
            A tuple of two floats, containing the desired
            figure width and height in inches.
        wspace:
            An optional float, specifying the horizontal distance between
            columns (as a fraction of the average axes width). Defaults to
            None, i.e. the current matplotlib spacing is kept.
        hspace:
            An optional float, specifying the vertical distance between
            rows (as a fraction of the average axes height). Defaults to
            None, i.e. the current matplotlib spacing is kept.

    Returns:
        fig:
            A matplotlib.figure.Figure instance
        ax:
            Whatever plt.subplots returns for this grid: a single
            pyplot.Axes when nrows == ncols == 1, otherwise an array of
            pyplot.Axes instances.
    """
    fig, ax = plt.subplots(nrows, ncols, figsize=size_xy)
    # Adjust the figure we just created explicitly instead of pyplot's
    # implicit "current figure". A None value leaves that spacing
    # unchanged; skip the call entirely when nothing was requested.
    if wspace is not None or hspace is not None:
        fig.subplots_adjust(wspace=wspace, hspace=hspace)
    return fig, ax
#############################
# Adding elements to a plot #
#############################
def title(ax: plt.Axes,
          plottitle: str,
          fontsize: float = 30,
          pad: float = 20) -> None:
    """Set the title of an existing plot.

    Args:
        ax:
            A pyplot.Axes instance (anything without a ``plot`` attribute
            is rejected).
        plottitle:
            The title text.
        fontsize:
            Optional font size of the title. Defaults to 30.
        pad:
            Optional gap between title and plot. Defaults to 20.

    Returns:
        None

    Raises:
        ValueError: if ``ax`` has no ``plot`` attribute.
    """
    looks_like_axes = hasattr(ax, 'plot')
    if not looks_like_axes:
        raise ValueError("Pass a valid plot in parameter ax.")
    ax.set_title(plottitle, pad=pad, fontsize=fontsize)
def labels(ax: plt.Axes,
           xlabel: Optional[str] = None,
           ylabel: Optional[str] = None,
           fontsize: float = 30,
           pad: float = 15) -> None:
    """Set the x- and/or y-axis label of an existing plot.

    A label that is left as None is not touched.

    Args:
        ax:
            A pyplot.Axes instance.
        xlabel:
            Optional text for the x-axis label.
        ylabel:
            Optional text for the y-axis label.
        fontsize:
            Optional font size of the labels. Defaults to 30.
        pad:
            Optional gap between labels and axes. Defaults to 15.

    Returns:
        None

    Raises:
        ValueError: if ``ax`` has no ``plot`` attribute.
    """
    if not hasattr(ax, 'plot'):
        raise ValueError("Pass a valid plot in parameter ax.")
    requested = (("set_xlabel", xlabel), ("set_ylabel", ylabel))
    for setter_name, text in requested:
        if text is None:
            continue
        # Look the setter up only when it is actually needed.
        getattr(ax, setter_name)(text, fontsize=fontsize, labelpad=pad)
def diagonal(ax: plt.Axes,
             colour: str = "black",
             alpha: float = 0.3,
             linestyle: str = "-",
             linewidth: float = 2) -> None:
    """Adds the 45 degrees diagonal (the line y = x) to an existing plot.

    The line runs from the smallest to the largest value covered by either
    axis's current view limits. Adding it does not rescale the view, so
    the plot keeps the limits it had before.

    Args:
        ax:
            A pyplot.Axes instance
        colour:
            An optional string containing the colour of the diagonal
            to be added. Defaults to "black".
        alpha:
            An optional float containing the alpha (opacity) of the
            diagonal to be added. Defaults to 0.3.
        linestyle:
            An optional string, specifying the line style of the
            diagonal to be added, Defaults to "-", i.e. a continuous line.
        linewidth:
            An optional float, specifying the width of the
            diagonal to be added. Defaults to 2.

    Returns:
        None

    Raises:
        ValueError: if ``ax`` has no ``plot`` attribute.
    """
    if not hasattr(ax, 'plot'):
        raise ValueError("Pass a valid plot in parameter ax.")
    # get_xlim()/get_ylim() return (left, right)/(bottom, top), which are
    # reversed on inverted axes; sort so the first element is the lower bound.
    x_low, x_high = sorted(ax.get_xlim())
    y_low, y_high = sorted(ax.get_ylim())
    minimum = min(x_low, y_low)
    maximum = max(x_high, y_high)
    # scalex/scaley=False: the diagonal usually extends beyond one axis's
    # range, and letting it trigger autoscaling would zoom out the user's plot.
    ax.plot([minimum, maximum], [minimum, maximum],
            color=colour,
            linestyle=linestyle,
            linewidth=linewidth,
            alpha=alpha,
            scalex=False,
            scaley=False)
def rectangle(ax: plt.Axes, x1: float, y1: float, x2: float, y2: float,
              **kwargs) -> None:
    """Convenience function for adding a rectangle to an existing plot
    without having to build a Rectangle patch and call ax.add_patch()
    manually. Takes the x and y coordinates of two opposite corners,
    instead of the x and y coordinates of one corner plus the rectangle
    width and height, as matplotlib's Rectangle would.

    Args:
        ax:
            A pyplot.Axes instance
        x1:
            A float, specifying the x coordinate of the first point.
        y1:
            A float, specifying the y coordinate of the first point.
        x2:
            A float, specifying the x coordinate of the second point.
        y2:
            A float, specifying the y coordinate of the second point.
        **kwargs:
            Named arguments such as color, fill, linewidth, linestyle.
            Passed to the matplotlib.patches.Rectangle constructor.

    Returns:
        None

    Raises:
        ValueError: if a coordinate cannot be converted to float, or if
            ``ax`` has no ``plot`` attribute.
    """
    try:
        x1, y1, x2, y2 = map(float, (x1, y1, x2, y2))
    except (TypeError, ValueError, OverflowError) as ex:
        # float() signals bad input with exactly these exception types.
        raise ValueError("Pass numbers in x1, y1, x2, y2.") from ex
    if not hasattr(ax, 'plot'):
        raise ValueError("Pass a valid plot in parameter ax.")
    # Anchor at the first point; width/height are signed differences.
    ax.add_patch(Rectangle((x1, y1), x2 - x1, y2 - y1, **kwargs))
def legend(ax: plt.Axes,
           loc: Union[str, int] = "best",
           fontsize: float = 30,
           frame: bool = False,
           **kwargs) -> None:
    """Adds a legend to an existing plot.

    Args:
        ax:
            A pyplot.Axes instance
        loc:
            A string or int, specifying the legend position,
            following matplotlib syntax. Defaults to "best".
        fontsize:
            A float, specifying the font size. Defaults to 30.
        frame:
            A bool, specifying whether to draw a frame around the legend.
            Defaults to False, i.e. no.
        **kwargs:
            Further named arguments accepted by ax.legend(), such as
            title, ncol or markerscale. Passed to ax.legend().

    Returns:
        None

    Raises:
        ValueError: if ``ax`` has no ``plot`` attribute.
    """
    if not hasattr(ax, 'plot'):
        raise ValueError("Pass a valid plot in parameter ax.")
    ax.legend(loc=loc, fontsize=fontsize, frameon=frame, **kwargs)
#############################
# Change elements of a plot #
#############################
def despine(ax: plt.Axes, which: SPINES = ('top', 'right')) -> None:
    """Hide selected spines (axis border lines) of an existing plot.

    Args:
        ax:
            A pyplot.Axes instance.
        which:
            Tuple of spine names to hide, any of "top", "bottom", "left",
            "right". Defaults to ("top", "right").

    Returns:
        None

    Raises:
        ValueError: if ``ax`` has no ``plot`` attribute.
    """
    if not hasattr(ax, 'plot'):
        raise ValueError("Pass a valid plot in parameter ax.")
    hidden = False
    for side in which:
        ax.spines[side].set_visible(hidden)
def respine(ax: plt.Axes, which: SPINES = ('top', 'right')) -> None:
    """Show selected spines (axis border lines) of an existing plot again.

    Args:
        ax:
            A pyplot.Axes instance.
        which:
            Tuple of spine names to show, any of "top", "bottom", "left",
            "right". Defaults to ("top", "right").

    Returns:
        None

    Raises:
        ValueError: if ``ax`` has no ``plot`` attribute.
    """
    if not hasattr(ax, 'plot'):
        raise ValueError("Pass a valid plot in parameter ax.")
    shown = True
    for side in which:
        ax.spines[side].set_visible(shown)
def ticklabelsize(ax: plt.Axes, which: str = "both", size: float = 30) -> None:
    """Set the tick label font size of one or both axes of an existing plot.

    Args:
        ax:
            A pyplot.Axes instance.
        which:
            Which axis to change: "x", "y", or "xy" / "yx" / "both" for
            both axes. Defaults to "both".
        size:
            Desired tick label size. Defaults to 30.

    Returns:
        None

    Raises:
        ValueError: if ``which`` is not one of the accepted values, or if
            ``ax`` has no ``plot`` attribute.
    """
    accepted = ("x", "y", "xy", "yx", "both")
    if which not in accepted:
        raise ValueError(
            'Parameter which must be one of "x", "y", "xy", "yx", "both".')
    if not hasattr(ax, 'plot'):
        raise ValueError("Pass a valid plot in parameter ax.")
    # After validation, "x"/"y" appear literally in every selector except
    # "both", which selects both axes.
    for axis_name in ("x", "y"):
        if which == "both" or axis_name in which:
            ax.tick_params(axis_name, labelsize=size)
def limits(ax: plt.Axes,
           xlimits: Optional[Tuple[float, float]] = None,
           ylimits: Optional[Tuple[float, float]] = None) -> None:
    """Sets the axis limits of an existing plot.

    A limit that is left as None is not changed.

    Args:
        ax:
            A pyplot.Axes instance
        xlimits:
            An optional tuple of two floats,
            containing the desired limits of the x-axis. Defaults to None.
        ylimits:
            An optional tuple of two floats,
            containing the desired limits of the y-axis. Defaults to None.

    Returns:
        None

    Raises:
        ValueError: if ``ax`` has no ``plot`` attribute.
    """
    if not hasattr(ax, 'plot'):
        raise ValueError("Pass a valid plot in parameter ax.")
    if xlimits is not None:
        ax.set_xlim(xlimits)
    if ylimits is not None:
        ax.set_ylim(ylimits)
def ticks_and_labels(ax: plt.Axes,
                     which: str,
                     ticks: List[float],
                     ticklabels: Optional[List[str]] = None) -> None:
    """Sets ticks and corresponding labels of one or both axes of an existing plot.

    When both axes are selected, the same ticks and labels are applied to
    each of them.

    Args:
        ax:
            A pyplot.Axes instance
        which:
            A string, specifying the axis.
            Possible values are "x", "y", "xy", "yx", "both".
        ticks:
            A list (or any iterable) of floats, containing the desired
            tick positions.
        ticklabels:
            An optional list (or any iterable) of strings, containing the
            desired axis labels corresponding to the specified tick
            positions. Defaults to None. If none are provided, tick labels
            will be set to str() of the provided tick positions.

    Returns:
        None

    Raises:
        ValueError: if ``which`` is not one of the accepted values, or if
            ``ax`` has no ``plot`` attribute.
    """
    if which not in ["x", "y", "xy", "yx", "both"]:
        raise ValueError(
            'Parameter which must be one of "x", "y", "xy", "yx", "both".')
    if not hasattr(ax, 'plot'):
        raise ValueError("Pass a valid plot in parameter ax.")
    # Materialise once: ticks (and labels) are read up to three times below,
    # which would silently exhaust a one-shot iterator such as a generator.
    ticks = list(ticks)
    if ticklabels is None:
        ticklabels = [str(tick) for tick in ticks]
    else:
        ticklabels = list(ticklabels)
    if which in ["x", "xy", "yx", "both"]:
        ax.set_xticks(ticks)
        ax.set_xticklabels(ticklabels)
    if which in ["y", "xy", "yx", "both"]:
        ax.set_yticks(ticks)
        ax.set_yticklabels(ticklabels)
def rotate_ticklabels(ax: plt.Axes, which: str, rotation: float) -> None:
    """Rotates tick labels of one or both axes of an existing plot.

    Only the rotation is changed; the tick labels keep following the
    axis's tick locator and formatter, so they stay correct when the
    view changes.

    Args:
        ax:
            A pyplot.Axes instance
        which:
            A string, specifying the axis. Possible values are "x", "y", "xy", "yx", "both".
        rotation:
            A float, containing the tick label angle in degrees.

    Returns:
        None

    Raises:
        ValueError: if ``which`` is not one of the accepted values, or if
            ``ax`` has no ``plot`` attribute.
    """
    if which not in ["x", "y", "xy", "yx", "both"]:
        raise ValueError(
            'Parameter which must be one of "x", "y", "xy", "yx", "both".')
    if not hasattr(ax, 'plot'):
        raise ValueError("Pass a valid plot in parameter ax.")
    # Re-setting get_xticklabels() as labels would replace the formatter with
    # a frozen snapshot of the current label strings (empty if the figure has
    # not been drawn yet); tick_params changes only the rotation.
    if which in ["x", "xy", "yx", "both"]:
        ax.tick_params(axis="x", labelrotation=rotation)
    if which in ["y", "xy", "yx", "both"]:
        ax.tick_params(axis="y", labelrotation=rotation)
def align_ticklabels(ax: plt.Axes,
                     which: str,
                     horizontal: Optional[str] = None,
                     vertical: Optional[str] = None) -> None:
    """Aligns tick labels of one axis of an existing plot.

    Both horizontal and vertical alignment may be specified. The alignment
    is set directly on the tick label text objects, so the label strings
    themselves are left untouched.

    Args:
        ax:
            A pyplot.Axes instance
        which:
            A string, specifying the axis. Possible values are "x", "y".
        horizontal:
            An optional string, containing the desired horizontal
            tick label alignment. Possible values are "center",
            "right", "left". Defaults to None, i.e. no change.
        vertical:
            An optional string, containing the desired vertical
            tick label alignment. Possible values are "center",
            "top", "bottom", "baseline". Defaults to None, i.e. no change.

    Returns:
        None

    Raises:
        ValueError: if ``which`` is not "x" or "y", or if ``ax`` has no
            ``plot`` attribute.
    """
    if which not in ["x", "y"]:
        raise ValueError('Parameter which must be one of "x", "y".')
    if not hasattr(ax, 'plot'):
        raise ValueError("Pass a valid plot in parameter ax.")
    # Setting alignment on the Text objects avoids set_*ticklabels(), which
    # would freeze the current label strings into a fixed formatter.
    tick_labels = ax.get_xticklabels() if which == "x" else ax.get_yticklabels()
    for label in tick_labels:
        if horizontal is not None:
            label.set_horizontalalignment(horizontal)
        if vertical is not None:
            label.set_verticalalignment(vertical)
##################
# Export figures #
##################
def save_png(filename: str, dpi: float = 300) -> None:
    """Write the currently active figure to a PNG file.

    The bounding box is trimmed tightly around the figure content.

    Args:
        filename:
            Target path and file name.
        dpi:
            Optional resolution in dots per inch. Defaults to 300.

    Returns:
        None
    """
    plt.savefig(filename, format="png", dpi=dpi,
                bbox_inches="tight")  # pragma: no cover
def save_svg(filename: str) -> None:
    """Write the currently active figure to an SVG file.

    The bounding box is trimmed tightly around the figure content.

    Args:
        filename:
            Target path and file name.

    Returns:
        None
    """
    plt.savefig(filename, format="svg",
                bbox_inches="tight")  # pragma: no cover
def save_pdf(filename: str) -> None:
    """Write the currently active figure to a PDF file.

    The bounding box is trimmed tightly around the figure content.

    Args:
        filename:
            Target path and file name.

    Returns:
        None
    """
    plt.savefig(filename, format="pdf",
                bbox_inches="tight")  # pragma: no cover