Skip to content

Commit

Permalink
Ggb works!
Browse files Browse the repository at this point in the history
  • Loading branch information
aadi-bh committed Nov 25, 2023
1 parent 42d2fa1 commit 0d91274
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 21 deletions.
59 changes: 41 additions & 18 deletions ggb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Gegenbauer polynomials -- the Gibbs complimentary basis to Fourier
# Gegenbauer polynomials -- a Gibbs complementary basis to Fourier

import numpy as np
from scipy.special import gegenbauer as ggb
Expand Down Expand Up @@ -38,22 +38,45 @@ def gam(n, lam):
r /= gamma(lam) * gamma(2 * lam) * factorial(n) * (n+lam)
return r

def ip(uh, n, lam, a, b):
z, wts = np.polynomial.legendre.leggauss(10*n+200);
'''
Weighted inner product over [a,b]
#### f must be function of x\in[ a, b],
#### g must be function of z\in[-1, 1].
'''
def wip(f, g, lam, a, b):
z, wts = np.polynomial.legendre.leggauss(10*lam + 200)
x = xi_inv(z, a, b)
fx = ifft_at(x, uh).real
Cz = C(z, n, lam)
fx = f(x)
gz = g(z)
wz = w(z, lam)
return np.sum(wts * fx * Cz * wz)

x = np.linspace(0, 2*pi, 16)
u = np.ones(x.shape)
uh = fft(u)
print(ip(uh, 0, 1, 1,2) / gam(0, 1))
print(ip(uh, 1, 1, 1,2) / gam(1, 1))
print(ip(uh, 2, 1, 1,2) / gam(2, 1))
print(ip(uh, 3, 1, 1,2) / gam(3, 1))
print(ip(uh, 4, 1, 1,2) / gam(4, 1))
# TODO
# Expand it in terms of the ggbs
# Patch it back into the solution.
return np.sum(wts * fx * gz * wz)

def ip_fft(uh, n, lam, a, b):
# f should be a function of x, g must be a function of z
f = lambda x: ifft_at(x, uh)
Cn = ggb(n, lam)
return wip(f, Cn, lam, a, b)

def expand(x, f, L):
'''
Expands f over the GGB polys up to degree L
and returns evaluation at x.
'''
degs = np.arange(0, L+1)
a = x[0]
b = x[-1]
m = np.empty((len(degs), len(x)))
for n in degs:
Cn = ggb(n, L)
m[n] = wip(f, Cn, L, a, b).real / gam(n, L) * Cn(xi(x, a, b))
return np.sum(m, axis = 0)

def expand_fft(x, uh, L):
'''
Expands the function (f = ifft(uh)) in terms of the Gegenbauer polys
and return the evaluation at each x.
'''
f = lambda x: ifft_at(x, uh)
return expand(x, f, L)

# TODO define a function that processes everything
23 changes: 21 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import op
import filters
import plots
import ggb


xmin, xmax = 0, 2 * pi
Expand All @@ -32,8 +33,11 @@
parser.add_argument('--add_visc', type=bool, default=False)
parser.add_argument('--filter', choices=('no_filter', 'exponential', 'cesaro', 'raisedcos', 'lanczos'), default='no_filter')
parser.add_argument('--filterp', type=int, default=1)
parser.add_argument('---ggb', type=bool, default=False)
parser.add_argument('--ggbL', type=int, default=5)
parser.add_argument('--max_lgN', type=int, default=7)
parser.add_argument('--integrator', choices=('solve_ivp', 'elrk4'), default='elrk4')
#parser.add_argument('--integrator', choices=('solve_ivp', 'elrk4'), default='elrk4')
parser.add_argument('--show_markers', type=bool, default=False)
args = parser.parse_args()

if __name__ == '__main__':
Expand Down Expand Up @@ -67,6 +71,7 @@
visc = op.semigroup_heat
tf = max(0, args.Tf)
cfl = min(1, args.cfl)
show_markers = args.show_markers

print("PDE :", args.pde)
print("TF :", tf)
Expand Down Expand Up @@ -132,7 +137,8 @@
v = u[-1]
t = times[-1]
n = len(v)
ax[0].plot(cgrid(n), ifft(v).real, "+", color= "red", markersize=5)
if show_markers:
ax[0].plot(cgrid(n), ifft(v).real, "+", color= "red", markersize=5)
plots.smoothplot(v, ax[0], label=str(n)+f", t={np.round(t, 3)}", linewidth=1)
plots.plot_resolution(v, ax[1], linewidth=0.5, markersize=0.5)

Expand All @@ -144,3 +150,16 @@

plt.close()
print("Done.")

## Gegenbauer test
u = u[-1]
x = cgrid(len(u))
# plots.smoothplot(u, plt)
plt.plot(x, ifft(u).real)
ai = np.where(x < 3, True, False)
y = x[ai]
v = ifft(u)
v[ai] = ggb.expand_fft(y, u, 3)
plt.plot(x, v.real)
# plots.smoothplot(fft(v), plt)
plt.show()
2 changes: 1 addition & 1 deletion plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def plot_resolution(c, ax, **kwargs):
ax.semilogy(k, np.abs(fftshift(c)), **kwargs)
return

def smoothplot(v, ax,nn=16, **plotargs):
def smoothplot(v, ax,nn=2048, **plotargs):
n = len(v)
w = pad(v, (nn - n)//2)
dy = (xmax - xmin) / n
Expand Down

0 comments on commit 0d91274

Please sign in to comment.