Skip to content

Commit

Permalink
GGB!
Browse files Browse the repository at this point in the history
  • Loading branch information
aadi-bh committed Nov 25, 2023
1 parent b81c9fc commit f80a819
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 9 deletions.
4 changes: 2 additions & 2 deletions burgers_filtered.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

def saw_tooth(x):
left = np.where(x < pi, 1, 0);
return x * left + (x - 2 * pi) * (1-left) / (2*pi)
return x * left + (x - 2 * pi) * (1-left)
def sin(x):
return np.sin(x)
def step(x):
Expand Down Expand Up @@ -164,7 +164,7 @@ def semigroup_none(dt, k, eps):
filter = create_filter(kk, sigma, {'p':p})
args = (N, M, filter)
u_hat_init = fft(initial_condition(x))
S_half, S = visc(dt, k, eps = 1e-2)
S_half, S = visc(dt, k, eps = 1e-1)
'''
output = solve_ivp(fun = rhs,
t_span = [0, tf],
Expand Down
20 changes: 20 additions & 0 deletions common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import numpy as np
from numpy import pi
from numpy.fft import *

xmin = 0
xmax = 2 * pi

def elrk4(SemiGroup,Nonlinear,y0,tinterval,dt,args):
y = y0
t, tf = tinterval
Expand Down Expand Up @@ -95,3 +99,19 @@ def freqs(n):
def cgrid(n, xmin=0, xmax=2*np.pi):
dx = (xmax - xmin) / n
return 0.5 * dx + np.arange(xmin, xmax, dx)

def ifft_at(z, uh):
'''
Evaluates the given truncated fourier series
at each value of x
'''
N = len(uh)
k = freqs(N)
dx = (xmax-xmin)/N
dz = z[1] - z[0]
# Need to shift because the first sample was at dx/2
# TODO Don't completely understand that yet, but it makes sense
y = z - dx/2
xk = np.tensordot(y, k, axes = 0)
return np.exp(1j * xk)@uh / N

22 changes: 20 additions & 2 deletions ggb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import numpy as np
from scipy.special import gegenbauer as ggb
from scipy.special import gamma, factorial
from common import *

'''
Maps interval [a,b] to [-1,1]
'''
def z(y, a, b):
def xi(y, a, b):
return -1 + 2 * (y-a)/(b-a)

'''
Expand All @@ -31,7 +32,24 @@ def gam(n, lam):
r /= gamma(lam) * gamma(2 * lam) * factorial(n) * (n+lam)
return r

def wtd_inner(uh, n, lam, a, b):
x = cgrid(2*len(uh)*(n+1), a, b)
z = xi(x, a, b)
dz = z[1] - z[0]
fx = ifft_at(x, uh)
Cz = C(z, n, lam)
wz = w(z, lam)
# Trapezoidal rule, but w = 0 at +-1
return dz * np.sum(fx * Cz * wz)

x = np.linspace(0, 2*pi, 16)
u = np.ones(x.shape)
uh = fft(u)
print(wtd_inner(uh, 0, 1, 1,2) / gam(0,1))
print(wtd_inner(uh, 1, 1, 1,2))
print(wtd_inner(uh, 2, 1, 1,2))
print(wtd_inner(uh, 3, 1, 1,2))
print(wtd_inner(uh, 4, 1, 1,2))
# TODO
# Need a way to compute the inner product on an arbitrary interval
# Expand it in terms of the ggbs
# Patch it back into the solution.
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@
v = u[-1]
t = times[-1]
n = len(v)
# ax[0].plot(cgrid(n), ifft(v).real, "+", color= "red", markersize=4)
ax[0].plot(cgrid(n), ifft(v).real, "+", color= "red", markersize=5)
plots.smoothplot(v, ax[0], label=str(n)+f", t={np.round(t, 3)}", linewidth=1)
plots.plot_resolution(v, ax[1], linewidth=0.5, markersize=0.5)

Expand Down
13 changes: 9 additions & 4 deletions plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from common import *
from filters import *

xmin, xmax = 0, 2 * np.pi

def filterplots():
x = np.linspace(-2, 2, 100)
fig, ax = plt.subplots(2,2, figsize=(14,8))
Expand All @@ -20,18 +18,25 @@ def filterplots():
ax[i][j].legend()
fig.savefig("filters.png")

def ggbplots():
return
x = np.linspace(-1, 1)
# TODO

def plot_resolution(c, ax, **kwargs):
k = fftshift(freqs(len(c)))
ax.semilogy(k, np.abs(fftshift(c)), **kwargs)
return

def smoothplot(v, ax,nn=2048, **plotargs):
def smoothplot(v, ax,nn=16, **plotargs):
n = len(v)
w = pad(v, (nn - n)//2)
dy = (xmax - xmin) / n
dz = (xmax - xmin) / nn
y = cgrid(n)
z = np.linspace(y[0], y[-1]+dy, nn, endpoint=True)
z = np.linspace(y[0], y[-1]+dy-dz, nn, endpoint=True)
ax.plot(z, ifft(w).real, **plotargs)
return z, ifft(w)

def convergence_plot(exactfile, filenames, saveas, **kwargs):
for fn in filenames:
Expand Down

0 comments on commit f80a819

Please sign in to comment.