Skip to content

Commit

Permalink
add a prototype of spectrum calculation to eqadet.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tjira committed Feb 24, 2024
1 parent dc786c2 commit 98bb2ea
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions education/python/eqadet.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def energy(wfn):
# define x and k space
x, k = np.linspace(args.range[0], args.range[1], args.points), 2 * np.pi * np.fft.fftfreq(args.points, dx)

# define the time axis
t = np.linspace(0, args.iters * args.tstep, args.iters + 1)

# define initial wavefunction and potential
psi0, V = eval(args.guess.replace("exp", "np.exp")), eval(args.potential.replace("exp", "np.exp"))

Expand Down Expand Up @@ -93,7 +96,7 @@ def energy(wfn):
D = [[energy(psi) + np.abs(psi)**2 for psi in state] for state in states]

# create the figure and definte tight layout
[fig, ax] = plt.subplots(); plt.tight_layout()
[fig, ax] = plt.subplots(1, 2, figsize=(12, 5)); plt.tight_layout()

# define minimum and maximum x values for plotting
xmin = np.min([np.min([np.min(x[np.abs(psi)**2 > 1e-8]) for psi in Si]) for Si in states])
Expand All @@ -108,15 +111,24 @@ def energy(wfn):
yminimag = min([min([energy(psi) + np.imag(psi).min() for psi in state]) for state in states])

# set limits of the plot
ax.set_xlim(xmin, xmax); ax.set_ylim(np.block([V, yminreal, yminimag]).min(), max([ymaxreal, ymaximag]))
ax[0].set_xlim(xmin, xmax); ax[0].set_ylim(np.block([V, yminreal, yminimag]).min(), max([ymaxreal, ymaximag]))

# plot the potential and initial wavefunctions
ax.plot(x, V); plots = [[ax.plot(x, np.real(state[0]))[0], ax.plot(x, np.imag(state[0]))[0]] for state in states]
ax[0].plot(x, V); plots = [[ax[0].plot(x, np.real(state[0]))[0], ax[0].plot(x, np.imag(state[0]))[0]] for state in states]

# animation update function
def update(j):
for i in range(len(plots)): plots[i][0].set_ydata(energy(states[i][j if j < len(states[i]) else -1]) + np.real(states[i][j if j < len(states[i]) else -1]))
for i in range(len(plots)): plots[i][1].set_ydata(energy(states[i][j if j < len(states[i]) else -1]) + np.imag(states[i][j if j < len(states[i]) else -1]))

# animate the wavefunction
ani = anm.FuncAnimation(fig, update, frames=np.max([len(state) for state in states]), interval=30); plt.show(); plt.close("all") # type: ignore
ani = anm.FuncAnimation(fig, update, frames=np.max([len(state) for state in states]), interval=30) # type: ignore

# calculate the autocorrelation function of a specified state
G = np.array([np.sum(np.conj(states[0][0]) * psi) * dx for psi in states[0]])

# calculate and plot the spectrum
ax[1].plot(np.fft.fftfreq(len(t), args.tstep), np.abs(np.fft.fftshift(np.fft.fft(G))))

# shot the plots
plt.show(); plt.close("all")

0 comments on commit 98bb2ea

Please sign in to comment.