Skip to content

Commit

Permalink
Merge pull request #106 from dgaines2/discontinuous-qpoints
Browse files Browse the repository at this point in the history
Discontinuous qpoints
  • Loading branch information
kbspooner authored Dec 20, 2024
2 parents d6cda08 + 0b5b357 commit 1bf61b1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 9 deletions.
22 changes: 17 additions & 5 deletions tp/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,11 +1595,23 @@ def converge_phonons(band_yaml, bandmin, bandmax, colour, alpha, linestyle,
dosax = ax[1]
ax = ax[0]

tp.plot.phonons.add_multi(ax, data, colour=colour, linestyle=linestyle,
marker=marker, label=label, bandmin=bandmin,
bandmax=bandmax, alpha=alpha,
xmarkkwargs={'color': xmarkcolour,
'linestyle': xmarklinestyle})
if len(band_yaml) == 1:
try:
colour = mpl.cm.get_cmap(colour)([0])
except ValueError:
pass
tp.plot.phonons.add_dispersion(ax, data[0], colour=colour,
linestyle=linestyle[0], marker=marker[0],
label=label, bandmin=bandmin,
bandmax=bandmax, alpha=alpha,
xmarkkwargs={'color': xmarkcolour,
'linestyle': xmarklinestyle})
else:
tp.plot.phonons.add_multi(ax, data, colour=colour, linestyle=linestyle,
marker=marker, label=label, bandmin=bandmin,
bandmax=bandmax, alpha=alpha,
xmarkkwargs={'color': xmarkcolour,
'linestyle': xmarklinestyle})
if dos is not None:
dosdata = tp.data.load.phonopy_dos(dos, poscar, atoms)
tp.plot.frequency.add_dos(dosax, dosdata, projected=projected,
Expand Down
24 changes: 20 additions & 4 deletions tp/plot/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,17 @@ def add_dispersion(ax, data, sdata=None, bandmin=None, bandmax=None, main=True,

# plotting

# avoid connecting bands at disconnected q-points
split_indices = [0, *np.where(np.diff(x) == 0)[0] + 1, len(x)]
for n in range(len(f[0])):
ax.plot(x, f[:,n], color=colour[n], linestyle=linestyle[n],
marker=marker[n], label=label[n], **kwargs)
for i in range(len(split_indices)-1):
starting_index = split_indices[i]
ending_index = split_indices[i+1]
x_i = x[starting_index:ending_index]
f_ni = f[starting_index:ending_index, n]
label_ni = label[n] if i == 0 else None
ax.plot(x_i, f_ni, color=colour[n], linestyle=linestyle[n],
label=label_ni, marker=marker[n], **kwargs)

# axes formatting

Expand Down Expand Up @@ -577,13 +585,21 @@ def add_alt_dispersion(ax, data, pdata, quantity, bandmin=None, bandmax=None,

# plotting

# avoid connecting bands at disconnected q-points
split_indices = [0, *np.where(np.diff(x2) == 0)[0] + 1, len(x2)]
for n in range(len(y2[0])):
if scatter:
ax.scatter(x2, y2[:,n], color=colour[n], linestyle=linestyle[n],
label=label[n], marker=marker[n], **kwargs)
else:
ax.plot(x2, y2[:,n], color=colour[n], linestyle=linestyle[n],
label=label[n], marker=marker[n], **kwargs)
for i in range(len(split_indices)-1):
starting_index = split_indices[i]
ending_index = split_indices[i+1]
x2_i = x2[starting_index:ending_index]
y2_ni = y2[starting_index:ending_index, n]
label_ni = label[n] if i == 0 else None
ax.plot(x2_i, y2_ni, color=colour[n], linestyle=linestyle[n],
label=label_ni, marker=marker[n], **kwargs)

# axes formatting

Expand Down

0 comments on commit 1bf61b1

Please sign in to comment.