Vlasov1d - moment scalars (#25)
* moment scalars for vlasov 1d

* ln

* entropy moments
spatio-temporal collision frequencies

* tested collisions

* working wavepackets and absorbing boundaries
joglekara authored Dec 22, 2023
1 parent 39231b9 commit bf06f2b
Showing 10 changed files with 262 additions and 54 deletions.
32 changes: 30 additions & 2 deletions adept/vlasov1d/helpers.py
@@ -264,12 +264,18 @@ def post_process(result, cfg: Dict, td: str):
os.makedirs(os.path.join(td, "plots"), exist_ok=True)
os.makedirs(os.path.join(td, "plots", "fields"), exist_ok=True)
os.makedirs(os.path.join(td, "plots", "fields", "lineouts"), exist_ok=True)
os.makedirs(os.path.join(td, "plots", "fields", "logplots"), exist_ok=True)

os.makedirs(os.path.join(td, "plots", "scalars"), exist_ok=True)

binary_dir = os.path.join(td, "binary")
os.makedirs(binary_dir)
# merge
# flds_paths = [os.path.join(flds_path, tf) for tf in flds_list]
# arr = xarray.open_mfdataset(flds_paths, combine="by_coords", parallel=True)
for k in result.ys.keys():
if k.startswith("field"):
fields_xr = store_fields(cfg, td, result.ys[k], result.ts[k], k)
fields_xr = store_fields(cfg, binary_dir, result.ys[k], result.ts[k], k)
t_skip = int(fields_xr.coords["t"].data.size // 8)
t_skip = t_skip if t_skip > 1 else 1
tslice = slice(0, -1, t_skip)
@@ -279,12 +285,34 @@ def post_process(result, cfg: Dict, td: str):
plt.savefig(os.path.join(td, "plots", "fields", f"spacetime-{nm[7:]}.png"), bbox_inches="tight")
plt.close()

np.log10(np.abs(fld)).plot()
plt.savefig(
os.path.join(td, "plots", "fields", "logplots", f"spacetime-log-{nm[7:]}.png"), bbox_inches="tight"
)
plt.close()

fld[tslice].T.plot(col="t", col_wrap=4)
plt.savefig(os.path.join(td, "plots", "fields", "lineouts", f"{nm[7:]}.png"), bbox_inches="tight")
plt.close()

elif k.startswith("default"):
scalars_xr = xarray.Dataset(
{k: xarray.DataArray(v, coords=(("t", result.ts["default"]),)) for k, v in result.ys["default"].items()}
)
scalars_xr.to_netcdf(os.path.join(binary_dir, f"scalars-t={round(scalars_xr.coords['t'].data[-1], 4)}.nc"))

for nm, srs in scalars_xr.items():
fig, ax = plt.subplots(1, 2, figsize=(10, 4), tight_layout=True)
srs.plot(ax=ax[0])
ax[0].grid()
np.log10(np.abs(srs)).plot(ax=ax[1])
ax[1].grid()
ax[1].set_ylabel("$log_{10}$(|" + nm + "|)")
fig.savefig(os.path.join(td, "plots", "scalars", f"{nm}.png"), bbox_inches="tight")
plt.close()

f_xr = store_f(cfg, result.ts, td, result.ys)

mlflow.log_metrics({"postprocess_time_min": round((time() - t0) / 60, 3)})

return {"fields": fields_xr, "dists": f_xr}
return {"fields": fields_xr, "dists": f_xr, "scalars": scalars_xr}
43 changes: 38 additions & 5 deletions adept/vlasov1d/integrator.py
@@ -6,6 +6,7 @@
import diffrax

from adept.vlasov1d.pushers import field, fokker_planck, vlasov
from adept.tf1d.pushers import get_envelope


class Stepper(diffrax.Euler):
@@ -124,24 +125,56 @@ def __init__(self, cfg):
self.compute_charges = partial(jnp.trapz, dx=cfg["grid"]["dv"], axis=1)
self.dt = self.cfg["grid"]["dt"]
self.driver = field.Driver(cfg["grid"]["x"])
self.nu_prof = cfg["nuee"]
self.kr_prof = 0.0

def nu_prof(self, t, nu_args):
t_L = nu_args["time"]["center"] - nu_args["time"]["width"] * 0.5
t_R = nu_args["time"]["center"] + nu_args["time"]["width"] * 0.5
t_wL = nu_args["time"]["rise"]
t_wR = nu_args["time"]["rise"]
x_L = nu_args["space"]["center"] - nu_args["space"]["width"] * 0.5
x_R = nu_args["space"]["center"] + nu_args["space"]["width"] * 0.5
x_wL = nu_args["space"]["rise"]
x_wR = nu_args["space"]["rise"]

nu_time = get_envelope(t_wL, t_wR, t_L, t_R, t)
if nu_args["time"]["bump_or_trough"] == "trough":
nu_time = 1 - nu_time
nu_time = nu_args["time"]["baseline"] + nu_args["time"]["bump_height"] * nu_time

nu_prof = get_envelope(x_wL, x_wR, x_L, x_R, self.cfg["grid"]["x"])
if nu_args["space"]["bump_or_trough"] == "trough":
nu_prof = 1 - nu_prof
nu_prof = nu_args["space"]["baseline"] + nu_args["space"]["bump_height"] * nu_prof

return nu_time * nu_prof

def __call__(self, t, y, args):
"""
This is just a wrapper around a Vlasov-Poisson + Fokker-Planck timestep
:param loop_carry:
:param current_params:
:param t:
:param y:
:param args:
:return:
"""

de = [self.driver(t + dt, args) for dt in self.vpfp.vlasov_poisson.dt_array]
dex = [val[0] for val in de]
djy = de[0][1]

if self.cfg["terms"]["fokker_planck"]["is_on"]:
nu_fp_prof = self.nu_prof(t=t, nu_args=args["terms"]["fokker_planck"])
else:
nu_fp_prof = None

if self.cfg["terms"]["krook"]["is_on"]:
nu_K_prof = self.nu_prof(t=t, nu_args=args["terms"]["krook"])
else:
nu_K_prof = None

electron_density_n = self.compute_charges(y["electron"])
e, f, force, pond = self.vpfp(y["electron"], y["a"], y["e"], dex, self.nu_prof, self.kr_prof)
e, f, force, pond = self.vpfp(y["electron"], y["a"], y["e"], dex, nu_fp_prof, nu_K_prof)
electron_density_np1 = self.compute_charges(f)

a = self.wave_solver(
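
The new `nu_prof` method builds the space-time collision-frequency profile by multiplying a temporal envelope and a spatial envelope obtained from `get_envelope` (imported from `adept.tf1d.pushers`; its implementation is not part of this diff). The sketch below mirrors that composition but assumes a tanh flat-top for the envelope, so `flat_top_envelope` and its functional form are assumptions, not the repo's helper.

```python
# Sketch only: flat_top_envelope is an assumed stand-in for the repo's get_envelope.
import jax.numpy as jnp


def flat_top_envelope(rise_l, rise_r, left, right, axis):
    # ~1 between `left` and `right`, rolling off over `rise_l` / `rise_r` (assumed tanh form).
    return 0.5 * (jnp.tanh((axis - left) / rise_l) - jnp.tanh((axis - right) / rise_r))


def nu_profile(t, x, nu_args):
    """Spatio-temporal collision frequency: (baseline + bump/trough envelope) in t times in x."""
    tt, xx = nu_args["time"], nu_args["space"]

    nu_time = flat_top_envelope(
        tt["rise"], tt["rise"], tt["center"] - 0.5 * tt["width"], tt["center"] + 0.5 * tt["width"], t
    )
    if tt["bump_or_trough"] == "trough":
        nu_time = 1.0 - nu_time
    nu_time = tt["baseline"] + tt["bump_height"] * nu_time

    nu_space = flat_top_envelope(
        xx["rise"], xx["rise"], xx["center"] - 0.5 * xx["width"], xx["center"] + 0.5 * xx["width"], x
    )
    if xx["bump_or_trough"] == "trough":
        nu_space = 1.0 - nu_space
    nu_space = xx["baseline"] + xx["bump_height"] * nu_space

    return nu_time * nu_space  # scalar-in-t factor times x-profile
```

With `bump_height: 0.0` in both the `time` and `space` blocks, as in the updated `epw.yaml` below, the profile reduces to the product of the two baselines (1.0e-5 there), recovering the behavior of the old scalar `nuee`.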
35 changes: 18 additions & 17 deletions adept/vlasov1d/pushers/fokker_planck.py
@@ -17,22 +17,23 @@ def __init__(self, cfg):
self.td_solver = TridiagonalSolver(self.cfg)

def __init_fp_operator__(self):
# if self.cfg["solver"]["fp_operator"] == "lenard_bernstein":
# return LenardBernstein(self.cfg)
# elif self.cfg["solver"]["fp_operator"] == "dougherty":
return Dougherty(self.cfg)
# else:
# raise NotImplementedError

def __call__(self, nu_fp: jnp.float64, nu_K: jnp.float64, f: jnp.ndarray, dt: jnp.float64) -> jnp.ndarray:
if np.any(self.cfg["nuee"] > 0.0):
if self.cfg["terms"]["fokker_planck"]["type"].casefold() == "lenard_bernstein":
return LenardBernstein(self.cfg)
elif self.cfg["terms"]["fokker_planck"]["type"].casefold() == "dougherty":
return Dougherty(self.cfg)
else:
raise NotImplementedError

def __call__(self, nu_fp: jnp.ndarray, nu_K: jnp.ndarray, f: jnp.ndarray, dt: jnp.float64) -> jnp.ndarray:
if self.cfg["terms"]["fokker_planck"]["is_on"]:
# The three diagonals representing collision operator for all x
cee_a, cee_b, cee_c = self.fp(nu=nu_fp, f_xv=f, dt=dt)
# Solve over all x
f = self.td_solver(cee_a, cee_b, cee_c, f)

# if (np.any(self.cfg["grid"]["kr_prof"] > 0.0)) and (np.any(self.cfg["grid"]["kt_prof"] > 0.0)):
# f = self.krook(nu_K, f, dt)
if self.cfg["terms"]["krook"]["is_on"]:
f = self.krook(nu_K, f, dt)

return f


@@ -72,9 +73,9 @@ def __call__(
"""

v0t_sq = self.vx_moment(f_xv * self.v[None, :] ** 2.0)
a = nu * dt * (-v0t_sq[:, None] / self.dv**2.0 + jnp.roll(self.v, 1)[None, :] / 2 / self.dv)
b = 1.0 + nu * dt * self.ones * (2.0 * v0t_sq[:, None] / self.dv**2.0)
c = nu * dt * (-v0t_sq[:, None] / self.dv**2.0 - jnp.roll(self.v, -1)[None, :] / 2 / self.dv)
a = nu[:, None] * dt * (-v0t_sq[:, None] / self.dv**2.0 + jnp.roll(self.v, 1)[None, :] / 2 / self.dv)
b = 1.0 + nu[:, None] * dt * self.ones * (2.0 * v0t_sq[:, None] / self.dv**2.0)
c = nu[:, None] * dt * (-v0t_sq[:, None] / self.dv**2.0 - jnp.roll(self.v, -1)[None, :] / 2 / self.dv)
return a, b, c


@@ -101,13 +102,13 @@ def __call__(
v0t_sq = self.vx_moment(f_xv * (self.v[None, :] - vbar[:, None]) ** 2.0)

a = (
nu
nu[:, None]
* dt
* (-v0t_sq[:, None] / self.dv**2.0 + (jnp.roll(self.v, 1)[None, :] - vbar[:, None]) / 2.0 / self.dv)
)
b = 1.0 + nu * dt * self.ones * (2.0 * v0t_sq[:, None] / self.dv**2.0)
b = 1.0 + nu[:, None] * dt * self.ones * (2.0 * v0t_sq[:, None] / self.dv**2.0)
c = (
nu
nu[:, None]
* dt
* (-v0t_sq[:, None] / self.dv**2.0 - (jnp.roll(self.v, -1)[None, :] - vbar[:, None]) / 2.0 / self.dv)
)
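
Read off from the `a`, `b`, `c` coefficients assembled in this file, the collision step looks like a backward-Euler, centered-difference discretization of the Dougherty operator, with Lenard-Bernstein as the $\bar v = 0$ special case. The following is a reconstruction from the diff rather than the repo's documented scheme, and the normalization used by `vx_moment` is not shown in this hunk.

```latex
% Continuum operator implied by the coefficients (reconstruction from the diff):
\[
  \frac{\partial f}{\partial t}
  = \nu\,\frac{\partial}{\partial v}\!\left[(v-\bar v)\,f + v_t^2\,\frac{\partial f}{\partial v}\right],
  \qquad
  \bar v \sim \int v\,f\,\mathrm{d}v,
  \qquad
  v_t^2 \sim \int (v-\bar v)^2 f\,\mathrm{d}v .
\]
% Backward Euler in time, centered differences in v:
\[
  f_j^{n+1} - \nu\,\Delta t\left[
    \frac{(v_{j+1}-\bar v)\,f_{j+1}^{n+1} - (v_{j-1}-\bar v)\,f_{j-1}^{n+1}}{2\,\Delta v}
    + v_t^2\,\frac{f_{j+1}^{n+1} - 2 f_j^{n+1} + f_{j-1}^{n+1}}{\Delta v^2}
  \right] = f_j^{n}
\]
% gives the tridiagonal system  a_j f_{j-1}^{n+1} + b_j f_j^{n+1} + c_j f_{j+1}^{n+1} = f_j^n  with
\[
  a_j = \nu\,\Delta t\left(-\frac{v_t^2}{\Delta v^2} + \frac{v_{j-1}-\bar v}{2\,\Delta v}\right),
  \qquad
  b_j = 1 + \frac{2\,\nu\,\Delta t\,v_t^2}{\Delta v^2},
  \qquad
  c_j = \nu\,\Delta t\left(-\frac{v_t^2}{\Delta v^2} - \frac{v_{j+1}-\bar v}{2\,\Delta v}\right),
\]
% matching the expressions in Dougherty.__call__ above, with nu and dt broadcast over x.
```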
44 changes: 35 additions & 9 deletions adept/vlasov1d/storage.py
@@ -8,7 +8,7 @@
import xarray as xr


def store_fields(cfg: Dict, td: str, fields: Dict, this_t: np.ndarray, prefix: str) -> xr.Dataset:
def store_fields(cfg: Dict, binary_dir: str, fields: Dict, this_t: np.ndarray, prefix: str) -> xr.Dataset:
"""
Stores fields to netcdf
@@ -19,8 +19,6 @@ def store_fields(cfg: Dict, td: str, fields: Dict, this_t: np.ndarray, prefix: s
:param this_t:
:return:
"""
binary_dir = os.path.join(td, "binary")
os.makedirs(binary_dir, exist_ok=True)

if any(x in ["x", "kx"] for x in cfg["save"][prefix].keys()):
crds = set(cfg["save"][prefix].keys()) - {"t", "func"}
Expand Down Expand Up @@ -103,14 +101,16 @@ def store_f(cfg: Dict, this_t: Dict, td: str, ys: Dict) -> xr.Dataset:
def get_field_save_func(cfg, k):
if {"t"} == set(cfg["save"][k].keys()):

def _calc_moment_(inp):
return jnp.trapz(inp, dx=cfg["grid"]["dv"], axis=1)

def fields_save_func(t, y, args):
temp = {
"n": jnp.trapz(y["electron"], dx=cfg["grid"]["dv"], axis=1),
"v": jnp.trapz(y["electron"] * cfg["grid"]["v"][None, :], dx=cfg["grid"]["dv"], axis=1),
}
temp = {"n": _calc_moment_(y["electron"]), "v": _calc_moment_(y["electron"] * cfg["grid"]["v"][None, :])}
v_m_vbar = cfg["grid"]["v"][None, :] - temp["v"][:, None]
temp["p"] = jnp.trapz(y["electron"] * v_m_vbar**2.0, dx=cfg["grid"]["dv"], axis=1)
temp["q"] = jnp.trapz(y["electron"] * v_m_vbar**3.0, dx=cfg["grid"]["dv"], axis=1)
temp["p"] = _calc_moment_(y["electron"] * v_m_vbar**2.0)
temp["q"] = _calc_moment_(y["electron"] * v_m_vbar**3.0)
temp["-flogf"] = _calc_moment_(y["electron"] * jnp.log(jnp.abs(y["electron"])))
temp["f^2"] = _calc_moment_(y["electron"] * y["electron"])
temp["e"] = y["e"]
temp["de"] = y["de"]

@@ -172,4 +172,30 @@ def get_save_quantities(cfg: Dict) -> Dict:
elif k.startswith("electron"):
cfg["save"][k]["func"] = get_dist_save_func(cfg, k)

cfg["save"]["default"] = {"t": {"ax": cfg["grid"]["t"]}, "func": get_default_save_func(cfg)}

return cfg


def get_default_save_func(cfg):
v = cfg["grid"]["v"][None, :]
dv = cfg["grid"]["dv"]

def _calc_mean_moment_(inp):
return jnp.mean(jnp.trapz(inp, dx=dv, axis=1))

def save(t, y, args):
scalars = {
"mean_P": _calc_mean_moment_(y["electron"] * v**2.0),
"mean_j": _calc_mean_moment_(y["electron"] * v),
"mean_n": _calc_mean_moment_(y["electron"]),
"mean_q": _calc_mean_moment_(y["electron"] * v**3.0),
"mean_-flogf": _calc_mean_moment_(-jnp.log(jnp.abs(y["electron"])) * jnp.abs(y["electron"])),
"mean_f2": _calc_mean_moment_(y["electron"] * y["electron"]),
"mean_de2": jnp.mean(y["de"] ** 2.0),
"mean_e2": jnp.mean(y["e"] ** 2.0),
}

return scalars

return save
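
For reference, the scalars returned by `get_default_save_func` are spatial means of velocity moments of the electron distribution (`jnp.trapz` over `v`, then `jnp.mean` over `x`), plus mean-square fields; written out, with $\langle\cdot\rangle_x$ the spatial mean:

```latex
\[
  \texttt{mean\_n}  = \Big\langle \int f\,\mathrm{d}v \Big\rangle_x, \qquad
  \texttt{mean\_j}  = \Big\langle \int v\,f\,\mathrm{d}v \Big\rangle_x, \qquad
  \texttt{mean\_P}  = \Big\langle \int v^2 f\,\mathrm{d}v \Big\rangle_x, \qquad
  \texttt{mean\_q}  = \Big\langle \int v^3 f\,\mathrm{d}v \Big\rangle_x,
\]
\[
  \texttt{mean\_-flogf} = \Big\langle -\!\int |f|\ln|f|\,\mathrm{d}v \Big\rangle_x, \qquad
  \texttt{mean\_f2} = \Big\langle \int f^2\,\mathrm{d}v \Big\rangle_x, \qquad
  \texttt{mean\_e2} = \big\langle e^2 \big\rangle_x, \qquad
  \texttt{mean\_de2} = \big\langle \texttt{de}^2 \big\rangle_x .
\]
```

Note that the x-resolved `p` and `q` stored by `fields_save_func` above use the centered forms $\int (v-\bar v)^2 f\,\mathrm{d}v$ and $\int (v-\bar v)^3 f\,\mathrm{d}v$, whereas the scalar `mean_P` and `mean_q` use raw powers of $v$.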
45 changes: 41 additions & 4 deletions configs/vlasov-1d/epw.yaml
@@ -27,8 +27,8 @@ density:
grid:
c_light: 10
dt: 0.5
nv: 512
nx: 32
nv: 4096
nx: 128
tmin: 0.
tmax: 480.0
vmax: 6.4
@@ -56,7 +56,7 @@ mlflow:
drivers:
ex:
'0':
a0: 1.e-3
a0: 1.e-4
k0: 0.3
t_center: 40.0
t_rise: 5.0
@@ -68,4 +68,41 @@ drivers:
x_width: 4000000.0
ey: {}

nuee: 1.0e-5
terms:
fokker_planck:
is_on: True
type: Dougherty
time:
baseline: 1.0e-5
bump_or_trough: bump
center: 0.0
rise: 25.0
slope: 0.0
bump_height: 0.0
width: 100000.0
space:
baseline: 1.0
bump_or_trough: bump
center: 0.0
rise: 25.0
slope: 0.0
bump_height: 0.0
width: 100000.0
krook:
is_on: False
time:
baseline: 1.0
bump_or_trough: bump
center: 0.0
rise: 25.0
slope: 0.0
bump_height: 0.0
width: 100000.0
space:
baseline: 1.0
bump_or_trough: bump
center: 0.0
rise: 25.0
slope: 0.0
bump_height: 0.0
width: 100000.0
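
The commit message mentions absorbing boundaries; switching the `krook` `space` block to `bump_or_trough: trough` turns the profile machinery above into a collision rate that is large near the walls and near zero in the interior. A hypothetical, self-contained illustration follows; the grid, width, and heights are made up rather than taken from `epw.yaml`, and `smooth_box` is an assumed stand-in for the repo's `get_envelope`.

```python
# Hypothetical numbers throughout; only the "trough" composition mirrors the diff.
import jax.numpy as jnp


def smooth_box(rise, left, right, axis):
    # Assumed tanh flat-top, ~1 inside [left, right]; not the repo's get_envelope.
    return 0.5 * (jnp.tanh((axis - left) / rise) - jnp.tanh((axis - right) / rise))


x = jnp.linspace(0.0, 4000.0, 1024)  # made-up spatial grid
space = {"baseline": 0.0, "bump_height": 10.0, "center": 2000.0, "width": 3600.0, "rise": 25.0}

box = smooth_box(
    space["rise"], space["center"] - 0.5 * space["width"], space["center"] + 0.5 * space["width"], x
)
nu_K = space["baseline"] + space["bump_height"] * (1.0 - box)  # "trough": high at the walls

print(float(nu_K[0]), float(nu_K[len(x) // 2]))  # ~10 at the boundary, ~0 mid-box
```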