Skip to content

Commit

Permalink
fix AttributeError in BSPlotter.get_plot(): 'Axes' object has no attr…
Browse files Browse the repository at this point in the history
…ibute 'gcf' (materialsproject#3327)
  • Loading branch information
janosh authored Sep 15, 2023
1 parent 9fd9eb5 commit b4e6208
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
18 changes: 9 additions & 9 deletions pymatgen/electronic_structure/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def add_bs(self, bs: BandStructureSymmLine | list[BandStructureSymmLine]) -> Non
# bands
self._nb_bands.extend([b.nb_bands for b in bs])

def _maketicks(self, ax: plt.Axes) -> plt.Axes:
def _make_ticks(self, ax: plt.Axes) -> plt.Axes:
"""Utility private method to add ticks to a band structure."""
ticks = self.get_ticks()
# Sanitize only plot the uniq values
Expand Down Expand Up @@ -689,7 +689,7 @@ def get_plot(
else:
ax.set_ylim(ylim)

self._maketicks(ax)
self._make_ticks(ax)

# Main X and Y Labels
ax.set_xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
Expand All @@ -708,8 +708,8 @@ def get_plot(
# auto tight_layout when resizing or pressing t
def fix_layout(event):
if (event.name == "key_press_event" and event.key == "t") or event.name == "resize_event":
ax.gcf().tight_layout()
ax.gcf().canvas.draw()
plt.tight_layout()
plt.gcf().canvas.draw()

ax.figure.canvas.mpl_connect("key_press_event", fix_layout)
ax.figure.canvas.mpl_connect("resize_event", fix_layout)
Expand Down Expand Up @@ -984,7 +984,7 @@ def get_projected_plots_dots(self, dictio, zero_to_efermi=True, ylim=None, vbm_c
for el in dictio:
for o in dictio[el]:
ax = plt.subplot(fig_rows + fig_cols + count)
self._maketicks(ax)
self._make_ticks(ax)
for b in range(len(data["distances"])):
for i in range(self._nb_bands):
ax.plot(
Expand Down Expand Up @@ -1056,7 +1056,7 @@ def get_elt_projected_plots(self, zero_to_efermi: bool = True, ylim=None, vbm_cb
count = 1
for el in self._bs.structure.elements:
plt.subplot(220 + count)
self._maketicks(ax)
self._make_ticks(ax)
for b in range(len(data["distances"])):
for i in range(self._nb_bands):
ax.plot(
Expand Down Expand Up @@ -1156,7 +1156,7 @@ def get_elt_projected_plots_color(self, zero_to_efermi=True, elt_ordered=None):
spins = [Spin.up]
if self._bs.is_spin_polarized:
spins = [Spin.up, Spin.down]
self._maketicks(ax)
self._make_ticks(ax)
for s in spins:
for b in range(len(data["distances"])):
for i in range(self._nb_bands):
Expand Down Expand Up @@ -1614,7 +1614,7 @@ def get_projected_plots_dots_patom_pmorb(
else:
raise ValueError("The invalid 'num_column' is assigned. It should be an integer.")

ax, shift = self._maketicks_selected(ax, branches)
ax, shift = self._make_ticks_selected(ax, branches)
br = -1
for b in branches:
br += 1
Expand Down Expand Up @@ -2079,7 +2079,7 @@ def orbital_label(list_orbitals):

return dictio_d, dictpa_d

def _maketicks_selected(self, ax: plt.Axes, branches: list[int]) -> tuple[plt.Axes, list[float]]:
def _make_ticks_selected(self, ax: plt.Axes, branches: list[int]) -> tuple[plt.Axes, list[float]]:
"""Utility private method to add ticks to a band structure with selected branches."""
if not ax.figure:
fig = plt.figure() # Create a figure object
Expand Down
8 changes: 4 additions & 4 deletions pymatgen/phonon/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def __init__(self, bs):
self._bs = bs
self._nb_bands = self._bs.nb_bands

def _maketicks(self, ax: plt.Axes) -> plt.Axes:
def _make_ticks(self, ax: plt.Axes) -> plt.Axes:
"""Utility private method to add ticks to a band structure."""
ticks = self.get_ticks()
# Sanitize only plot the uniq values
Expand Down Expand Up @@ -325,7 +325,7 @@ def get_plot(self, ylim=None, units="thz") -> plt.Axes:
linewidth=band_linewidth,
)

self._maketicks(ax)
self._make_ticks(ax)

# plot y=0 line
ax.axhline(0, linewidth=1, color="k")
Expand Down Expand Up @@ -434,7 +434,7 @@ def get_proj_plot(

u = freq_units(units)
fig, ax = plt.subplots(figsize=(12, 8), dpi=300)
self._maketicks(ax)
self._make_ticks(ax)

data = self.bs_plot_data()
k_dist = np.array(data["distances"]).flatten()
Expand Down Expand Up @@ -956,7 +956,7 @@ def get_plot_gs(self, ylim=None):
linewidth=2,
)

self._maketicks(ax)
self._make_ticks(ax)

# plot y=0 line
ax.axhline(0, linewidth=1, color="k")
Expand Down

0 comments on commit b4e6208

Please sign in to comment.