Skip to content

Commit

Permalink
Much progress for filters
Browse files Browse the repository at this point in the history
  • Loading branch information
aadi-bh committed Nov 24, 2023
1 parent 1f5ead3 commit de34fc3
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 43 deletions.
23 changes: 23 additions & 0 deletions common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,28 @@ def plot_resolution(c, ax, kwargs):
ax.semilogy(k, np.abs(fftshift(c)), **kwargs)
return

def pad(c, m):
# ASSUME c is fft(u)
N = len(c)
newN = 2*m + N
r = fftshift(c)
r = np.r_[np.zeros(m), r, np.zeros(m)]
r *= newN / N
r = ifftshift(r)
return r

def unpad(c, m):
N = len(c)
newN = N - 2 * m
r = fftshift(c)
r = r[m:m + newN]
r *= newN / N
r = ifftshift(r)
return r

def freqs(n):
return fftfreq(n, 1./ (n))

def cgrid(n, xmin=0, xmax=2*np.pi):
dx = (xmax - xmin) / n
return 0.5 * dx + np.arange(xmin, xmax, dx)
23 changes: 14 additions & 9 deletions filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@
import numpy as np

def exponential(eta, p=1):
return np.exp(-35* np.power(eta, 2*p))
return np.exp(-35 * np.power(eta, 2*p))

def cesaro(eta):
return 1 - eta
def cesaro(eta, **kwargs):
r = 1 - np.abs(eta)
r *= np.where(np.abs(eta) > 1, 0, 1)
return r

def raisedcos(eta):
return 0.5 * (1 + np.cos(np.pi * eta))
def raisedcos(eta, **kwargs):
r = 0.5 * (1 + np.cos(np.pi * eta))
r *= np.where(np.abs(eta) > 1, 0, 1)
return r

def lanczos(eta):
return np.sin(np.pi * eta) / (np.pi * eta)
def lanczos(eta, **kwargs):
r = np.sinc(np.pi * eta)
return r

def no_filter(eta):
return 1.0
def no_filter(eta, **kwargs):
return np.ones(len(eta))
1 change: 1 addition & 0 deletions ic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Initial conditions
import numpy as np
from numpy import pi

def saw_tooth(x):
left = np.where(x <= pi, 1, 0);
Expand Down
49 changes: 35 additions & 14 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
parser.add_argument('--Tf', type=float, help='Final time', default=1.0)
parser.add_argument('--pde', choices=('linadv', 'burgers'), default='burgers')
parser.add_argument('--add_visc', type=bool, default=False)
parser.add_argument('--use_filter', type=bool, default=False)
parser.add_argument('--filter', choices=('no_filter', 'exponential', 'cesaro', 'raisedcos', 'lanczos'), default='no_filter')
parser.add_argument('--filterp', type=int, default=1)
parser.add_argument('--max_lgN', type=int, default=7)
parser.add_argument('--integrator', choices=('solve_ivp', 'elrk4'), default='elrk4')
args = parser.parse_args()
Expand All @@ -38,7 +39,7 @@
rhs = op.linadv
initial_condition = ic.sin
visc = op.semigroup_none
USE_FILTER = False
sigma = filters.no_filter
tf = 2
cfl = 0.5
if args.pde == 'burgers':
Expand All @@ -51,35 +52,44 @@
initial_condition = ic.step
elif args.ic == 'bump':
initial_condition = ic.bump
if args.filter == 'no_filter':
sigma = filters.no_filter;
elif args.filter == 'exponential':
sigma = filters.exponential
elif args.filter == 'cesaro':
sigma = filters.cesaro
elif args.filter == 'raisedcos':
sigma = filters.raisedcos
elif args.filter == 'lanczos':
sigma = filters.lanczos
if args.add_visc == True:
visc = op.semigroup_heat
tf = max(0, args.Tf)
cfl = min(1, args.cfl)
USE_FILTER = args.use_filter

print("PDE :", args.pde)
print("TF :", tf)
print("CFL :", cfl)
print("IC :", args.ic)
print("FILTER:", args.use_filter)
print("FILTER:", args.filter)
print("VISC :", args.add_visc)

plotname = f"{args.pde}-visc{str(args.add_visc)}-{tf}-{args.ic}-{args.filter}.png"
sols = []
for i in range(0, args.max_lgN - 4 + 1):
N = np.power(2, i + 4);
N = 64
print(f"N={N}")
M = 3 * N // 2;
m = M // 2;
NN = (2 * m) + N
dx = (xmax-xmin) / N
dt = (cfl * dx)
x = np.arange(xmin, xmax, dx)
x = cgrid(N)
k = freqs(N)
kk = freqs(NN)
filter = np.ones(len(kk))
p = i
if USE_FILTER == True:
filter = create_filter(kk, sigma, {'p':p})
filter = create_filter(kk, sigma, {'p':p})
args = (N, M, filter)
u_hat_init = fft(initial_condition(x))
S_half, S = visc(dt, k, eps = 1e-2)
Expand All @@ -97,6 +107,8 @@
filename = f"{N}.txt"
np.savetxt(filename, np.vstack((k, u[-1].real)))
print("Saved solution to " + filename)
print("Breaking for debug")
break

# PLOT
fig, ax = plt.subplots(nrows=1, ncols=2, figsize = (14, 8), width_ratios=[3,1])
Expand All @@ -107,18 +119,27 @@
god = np.loadtxt("burg3_GOD_3.txt").transpose()
elif tf == 0.2:
god = np.loadtxt("burg3_GOD_1.txt").transpose()
ax[0].plot(god[0] * 2 * np.pi, god[1], 'ko', markersize=0.1, label="Godunov flux")
ax[0].plot(god[0] * 2 * np.pi, god[1], 'ko', markersize=0.8, label="Godunov flux")

fig.tight_layout()
nn = 2048
# x = np.linspace(xmin, xmax, 2048, endpoint=False)
for (times, u) in sols:
ax[0].plot(x, ifft(u[-1]).real, label=str(N)+f", t={np.round(times[-1], 3)}")
plot_resolution(u[-1], ax[1], {'linewidth':0.5, 'markersize':0.5})
fig, ax = plt.subplots(nrows=1, ncols=2, figsize = (14, 8), width_ratios=[3,1])
v = u[-1]
t = times[-1]
n = len(v)
dx = (xmax - xmin) / n
y = np.arange(xmin, xmax, dx) + 0.5 * dx
# y = np.linspace(xmin, xmax, len(v), endpoint=False)
ax[0].plot(y, ifft(v).real, label=str(n)+f", t={np.round(t, 3)}")
plot_resolution(v, ax[1], {'linewidth':0.5, 'markersize':0.5})
w = pad(v, (nn - n)//2)
z = np.linspace(dx/2, xmax - dx/2, 2048)
ax[0].plot(z, ifft(w).real, label=str(n))
x = np.linspace(xmin, xmax, 1000)
ax[0].plot(x, initial_condition(x), linewidth=0.1, color='k', label="init")
ax[0].legend()
fig.savefig(f"{rhs}-{str(USE_FILTER)}.png")
fig.savefig(plotname)

plt.close()
print("Done.")

22 changes: 2 additions & 20 deletions op.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,8 @@
from numpy import pi
from common import *

def pad(c, m):
# ASSUME c is fft(u) / len(u)
N = len(c)
newN = 2*m + N
r = fftshift(c)
r = np.r_[np.zeros(m), r, np.zeros(m)]
r *= newN / N
r = ifftshift(r)
return r

def unpad(c, m):
N = len(c)
newN = N - 2 * m
r = fftshift(c)
r = r[m:m + newN]
r *= newN / N
r = ifftshift(r)
return r

def linadv(t, u_hat, N, M, filter, a=1):

def linadv(t, u_hat, N, M, filter, a= 2 * pi):
NN = (2 * M//2) + N
u_hat = pad(u_hat, M//2)
kk = freqs(NN)
Expand Down

0 comments on commit de34fc3

Please sign in to comment.