Skip to content

Commit

Permalink
TPD source term (#24)
Browse files Browse the repository at this point in the history
Ended up needing

Noise - now in input
Units - now in input
Automatic collision frequency - now in input
  • Loading branch information
joglekara authored Dec 16, 2023
1 parent 6be3c71 commit 39231b9
Show file tree
Hide file tree
Showing 35 changed files with 403 additions and 231 deletions.
10 changes: 5 additions & 5 deletions adept/lpse2d/core/epw.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self, cfg):
self.wp0 = cfg["plasma"]["wp0"]
self.n0 = np.sqrt(self.wp0)
self.w0 = cfg["drivers"]["E0"]["w0"]
self.nuei = cfg["plasma"]["nu_ei"]
self.nuei = -cfg["units"]["derived"]["nuei_norm"]
self.kx = cfg["grid"]["kx"]
self.ky = cfg["grid"]["ky"]
self.dx = cfg["grid"]["dx"]
Expand Down Expand Up @@ -172,7 +172,7 @@ def calc_tpd_source_step(self, phi: jax.Array, e0: jax.Array, nb: jax.Array, t:
)
return coeff * (term1 + term2)

def update_potential(self, y):
def update_potential(self, t, y):
# do the equation first --- this is equation 54 so far
if self.cfg["terms"]["epw"]["linear"]:
osc_term, damping_term = self.calc_linear_step(y["temperature"])
Expand All @@ -185,8 +185,8 @@ def update_potential(self, y):
eh_x = eh_x * jnp.exp(density_step * self.dt)
new_phi = self.get_phi_from_eh(eh_x)

if self.cfg["terms"]["tpd"]["source"]:
new_phi += self.dt * self.calc_tpd_source_step(y["phi"], y["e0"], y["nb"], y["t"])
if self.cfg["terms"]["epw"]["source"]["tpd"]:
new_phi += self.dt * self.calc_tpd_source_step(y["phi"], y["e0"], y["nb"], t)

else:
raise NotImplementedError("The linear term is necessary to run the code")
Expand All @@ -208,7 +208,7 @@ def update_delta(self, t, y, args):

def __call__(self, t, y, args):
# push the equation of motion for the potential
y["phi"] = self.update_potential(y)
y["phi"] = self.update_potential(t, y)

if ("E2" in self.cfg["drivers"].keys()) or self.cfg["terms"]["epw"]["trapping"]["active"]:
eh = self.get_eh_x(y["phi"])
Expand Down
159 changes: 97 additions & 62 deletions adept/lpse2d/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@


import matplotlib.pyplot as plt
import jax
import jax, pint
from jax import numpy as jnp
import numpy as np
from scipy import constants
import equinox as eqx
from diffrax import ODETerm
import xarray as xr
from plasmapy.formulary.collisions.frequencies import fundamental_electron_collision_freq

from adept.lpse2d.core import integrator, driver


def get_derived_quantities(cfg_grid: Dict) -> Dict:
def get_derived_quantities(cfg: Dict) -> Dict:
"""
This function just updates the config with the derived quantities that are only integers or strings.
Expand All @@ -24,6 +25,8 @@ def get_derived_quantities(cfg_grid: Dict) -> Dict:
:param cfg_grid:
:return:
"""
cfg_grid = cfg["grid"]

cfg_grid["dx"] = cfg_grid["xmax"] / cfg_grid["nx"]
cfg_grid["dy"] = cfg_grid["ymax"] / cfg_grid["ny"]

Expand All @@ -37,7 +40,11 @@ def get_derived_quantities(cfg_grid: Dict) -> Dict:
else:
cfg_grid["max_steps"] = cfg_grid["nt"] + 4

return cfg_grid
cfg = get_more_units(cfg)

cfg["grid"] = cfg_grid

return cfg


def get_save_quantities(cfg: Dict) -> Dict:
Expand Down Expand Up @@ -130,33 +137,48 @@ def init_state(cfg: Dict, td=None) -> Dict:
:return: state: Dict
"""

e0 = jnp.zeros((cfg["grid"]["nx"], cfg["grid"]["ny"], 2), dtype=jnp.complex128)
phi = jnp.zeros((cfg["grid"]["nx"], cfg["grid"]["ny"]), dtype=jnp.complex128)
# e0 = jnp.zeros((cfg["grid"]["nx"], cfg["grid"]["ny"], 2), dtype=jnp.complex128)
# phi = jnp.zeros((cfg["grid"]["nx"], cfg["grid"]["ny"]), dtype=jnp.complex128)
# phi += (
# 1e-3
# * jnp.exp(-(((cfg["grid"]["x"][:, None] - 2000) / 400.0) ** 2.0))
# * jnp.exp(-1j * 0.2 * cfg["grid"]["x"][:, None])
# )
# e0 = jnp.concatenate(
# [jnp.exp(1j * cfg["drivers"]["E0"]["k0"] * cfg["grid"]["x"])[:, None] for _ in range(cfg["grid"]["ny"])], axis=-1
# )
# e0 = jnp.concatenate([e0[:, :, None], jnp.zeros_like(e0)[:, :, None]], axis=-1)
# e0 *= cfg["drivers"]["E0"]["e0"]
e0 = jnp.concatenate(
[jnp.exp(1j * cfg["drivers"]["E0"]["k0"] * cfg["grid"]["x"])[:, None] for _ in range(cfg["grid"]["ny"])],
axis=-1,
)
e0 = jnp.concatenate([e0[:, :, None], jnp.zeros_like(e0)[:, :, None]], axis=-1)
e0 *= cfg["drivers"]["E0"]["a0"]

if cfg["density"]["noise"]["type"] == "uniform":
random_amps_x = np.random.uniform(
cfg["density"]["noise"]["min"], cfg["density"]["noise"]["max"], cfg["grid"]["nx"]
)
random_amps_y = np.random.uniform(
cfg["density"]["noise"]["min"], cfg["density"]["noise"]["max"], cfg["grid"]["ny"]
)
elif cfg["density"]["noise"]["type"] == "normal":
loc = 0.5 * (cfg["density"]["noise"]["min"] + cfg["density"]["noise"]["max"])
scale = 1.0
random_amps_x = np.random.normal(loc, scale, cfg["grid"]["nx"])
random_amps_y = np.random.normal(loc, scale, cfg["grid"]["ny"])

random_amps_x = 1.0e-12 * np.random.uniform(0.1, 1, cfg["grid"]["nx"])
random_amps_y = 1.0e-12 * np.random.uniform(0.1, 1, cfg["grid"]["ny"])
else:
raise NotImplementedError

# phi = jnp.sum(random_amps_x * jnp.exp(1j * cfg["grid"]["kx"][None, :] * cfg["grid"]["x"][:, None]), axis=-1)[
# :, None
# ]
# phi += jnp.sum(random_amps_y * jnp.exp(1j * cfg["grid"]["ky"][None, :] * cfg["grid"]["y"][:, None]), axis=-1)[
# None, :
# ]
# phi = jnp.fft.fft2(phi)
phi = jnp.sum(random_amps_x * jnp.exp(1j * cfg["grid"]["kx"][None, :] * cfg["grid"]["x"][:, None]), axis=-1)[
:, None
]
phi += jnp.sum(random_amps_y * jnp.exp(1j * cfg["grid"]["ky"][None, :] * cfg["grid"]["y"][:, None]), axis=-1)[
None, :
]
phi = jnp.fft.fft2(phi)

state = {
"e0": e0,
"nb": (0.8 + 0.4 * cfg["grid"]["x"] / cfg["grid"]["xmax"])[:, None] * np.ones_like(phi, dtype=np.float64),
"nb": (cfg["density"]["offset"] + cfg["density"]["slope"] * cfg["grid"]["x"] / cfg["grid"]["xmax"])[:, None]
* np.ones_like(phi, dtype=np.float64),
"temperature": jnp.ones_like(e0[..., 0], dtype=jnp.float64),
"dn": jnp.zeros_like(e0[..., 0], dtype=jnp.float64),
"phi": phi,
Expand Down Expand Up @@ -209,50 +231,51 @@ def init_state(cfg: Dict, td=None) -> Dict:
return {k: v.view(dtype=np.float64) for k, v in state.items()}


def calc_e0(cfg):
e_laser = np.sqrt(2.0 * (float(cfg["drivers"]["E0"]["intensity"]) * 100) / constants.c / constants.epsilon_0)
e_norm = constants.m_e * cfg["norms"]["velocity"] * cfg["norms"]["frequency"] / constants.e
def get_more_units(cfg: Dict):
"""
cfg["norms"]["electric field"] = e_norm
cfg["norms"]["laser field"] = e_laser
:type cfg: object
"""

cfg["drivers"]["E0"]["e0"] = e_laser / e_norm
cfg["drivers"]["E0"]["k0"] = np.sqrt(
(cfg["drivers"]["E0"]["w0"] ** 2.0 - cfg["plasma"]["wp0"] ** 2.0) / cfg["norms"]["c"] ** 2.0
)
ureg = pint.UnitRegistry()
_Q = ureg.Quantity
import astropy.units as u

print("laser parameters: ")
print(f'w0 = {round(cfg["drivers"]["E0"]["w0"],4)}')
print(f'k0 = {round(cfg["drivers"]["E0"]["k0"],4)}')
print(f"a0 = {round(e_laser / e_norm, 4)}")
print()
n0 = _Q(cfg["units"]["normalizing density"]).to("1/cc")
wp0 = np.sqrt(n0 * ureg.e**2.0 / (ureg.m_e * ureg.epsilon_0)).to("rad/s")
T0 = _Q(cfg["units"]["normalizing temperature"]).to("eV")
v0 = np.sqrt(2.0 * T0 / ureg.m_e).to("m/s")
c_light = _Q(1.0 * ureg.c).to("m/s") / v0

return cfg
_nuei_ = fundamental_electron_collision_freq(
T_e=(Te := _Q(cfg["units"]["electron temperature"]).to("eV")).magnitude * u.eV,
n_e=n0.to("1/m^3").magnitude / u.m**3,
ion=f'{cfg["units"]["gas fill"]} {cfg["units"]["ionization state"]}+',
).value
cfg["units"]["derived"]["nuei"] = _Q(f"{_nuei_} Hz")
cfg["units"]["derived"]["nuei_norm"] = (cfg["units"]["derived"]["nuei"].to("rad/s") / wp0).magnitude

lambda_0 = _Q(cfg["units"]["laser wavelength"])
laser_frequency = (2 * np.pi / lambda_0 * ureg.c).to("rad/s")
laser_period = (1 / laser_frequency).to("fs")

def calc_norms(cfg: Dict):
"""
e_laser = np.sqrt(2.0 * _Q(cfg["drivers"]["E0"]["intensity"]) / ureg.c / ureg.epsilon_0).to("V/m")
e_norm = (ureg.m_e * (np.sqrt(2.0 * Te / ureg.m_e).to("m/s")) * laser_frequency / ureg.e).to("V/m")

:type cfg: object
"""
cfg["norms"] = {}
cfg["norms"]["n0"] = float(cfg["plasma"]["density"])
cfg["norms"]["T0"] = cfg["plasma"]["temperature"]
cfg["norms"]["frequency"] = np.sqrt(cfg["norms"]["n0"] * constants.e**2.0 / constants.m_e / constants.epsilon_0)
cfg["norms"]["velocity"] = (
2.0
* np.sqrt(
np.average(cfg["plasma"]["temperature"])
/ (1000.0 * constants.physical_constants["electron mass energy equivalent in MeV"][0])
)
* constants.c
)
cfg["norms"]["c"] = constants.c / cfg["norms"]["velocity"]
cfg["units"]["derived"]["electric field"] = e_norm
cfg["units"]["derived"]["laser field"] = e_laser

cfg["norms"]["space"] = cfg["norms"]["velocity"] / cfg["norms"]["frequency"]
cfg["norms"]["time"] = 1.0 / cfg["norms"]["frequency"]
cfg["drivers"]["E0"]["w0"] = (laser_frequency / wp0).magnitude
cfg["drivers"]["E0"]["a0"] = (e_laser / e_norm).magnitude
cfg["drivers"]["E0"]["k0"] = np.sqrt(
(cfg["drivers"]["E0"]["w0"] ** 2.0 - cfg["plasma"]["wp0"] ** 2.0) / c_light.magnitude**2.0
)

cfg = calc_e0(cfg)
print("laser parameters: ")
print(f'w0 = {round(cfg["drivers"]["E0"]["w0"], 4)}')
print(f'k0 = {round(cfg["drivers"]["E0"]["k0"], 4)}')
print(f'a0 = {round(cfg["drivers"]["E0"]["a0"], 4)}')
print()

return cfg

Expand Down Expand Up @@ -319,7 +342,11 @@ def plot_kt(kfields, td):
plt.savefig(os.path.join(fld_dir, f"{k}_kx.png"), bbox_inches="tight")
plt.close()

kx = kfields.coords["kx"].data
# np.log10(np.abs(v[tslice, :, :])).T.plot(col="t", col_wrap=4)
# plt.savefig(os.path.join(fld_dir, f"{k}_kx_ky.png"), bbox_inches="tight")
# plt.close()
#
# kx = kfields.coords["kx"].data


def post_process(result, cfg: Dict, td: str) -> Tuple[xr.Dataset, xr.Dataset]:
Expand All @@ -334,10 +361,7 @@ def post_process(result, cfg: Dict, td: str) -> Tuple[xr.Dataset, xr.Dataset]:

def make_xarrays(cfg, this_t, state, td):
phi_vs_t = state["phi"].view(np.complex128)
phi_k = xr.DataArray(
phi_vs_t,
coords=(("t", this_t), ("kx", cfg["grid"]["kx"]), ("ky", cfg["grid"]["ky"])),
)
phi_k = xr.DataArray(phi_vs_t, coords=(("t", this_t), ("kx", cfg["grid"]["kx"]), ("ky", cfg["grid"]["ky"])))

ex_k = xr.DataArray(
-1j * cfg["grid"]["kx"][:, None] * phi_vs_t,
Expand All @@ -362,11 +386,22 @@ def make_xarrays(cfg, this_t, state, td):
-np.fft.ifft2(1j * cfg["grid"]["ky"][None, :] * phi_vs_t) / cfg["grid"]["nx"] / cfg["grid"]["ny"] * 4,
coords=(("t", this_t), ("x", cfg["grid"]["x"]), ("y", cfg["grid"]["y"])),
)

e0x = xr.DataArray(
state["e0"].view(np.complex128)[..., 0],
coords=(("t", this_t), ("x", cfg["grid"]["x"]), ("y", cfg["grid"]["y"])),
)

e0y = xr.DataArray(
state["e0"].view(np.complex128)[..., 1],
coords=(("t", this_t), ("x", cfg["grid"]["x"]), ("y", cfg["grid"]["y"])),
)

delta = xr.DataArray(state["delta"], coords=(("t", this_t), ("x", cfg["grid"]["x"]), ("y", cfg["grid"]["y"])))

kfields = xr.Dataset({"phi": phi_k, "ex": ex_k, "ey": ey_k})
fields = xr.Dataset({"phi": phi_x, "ex": ex, "ey": ey, "delta": delta})
kfields.to_netcdf(os.path.join(td, "binary", "k-fields.xr"), engine="h5netcdf", invalid_netcdf=True)
fields = xr.Dataset({"phi": phi_x, "ex": ex, "ey": ey, "delta": delta, "e0_x": e0x, "e0_y": e0y})
# kfields.to_netcdf(os.path.join(td, "binary", "k-fields.xr"), engine="h5netcdf", invalid_netcdf=True)
fields.to_netcdf(os.path.join(td, "binary", "fields.xr"), engine="h5netcdf", invalid_netcdf=True)

return kfields, fields
Expand Down
2 changes: 1 addition & 1 deletion adept/lpse2d/train_damping.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def remote_run(run_id, t_or_v):
actual_ek1 = xr.open_dataarray(
misc.download_file("ground_truth.nc", artifact_uri=mlflow_run.info.artifact_uri, destination_path=td)
)
mod_defaults["grid"] = helpers.get_derived_quantities(mod_defaults["grid"])
mod_defaults = helpers.get_derived_quantities(mod_defaults)
misc.log_params(mod_defaults)
mod_defaults["grid"] = helpers.get_solver_quantities(mod_defaults["grid"])
mod_defaults = helpers.get_save_quantities(mod_defaults)
Expand Down
7 changes: 6 additions & 1 deletion adept/tf1d/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def post_process(result, cfg: Dict, td: str) -> Dict:
return datasets


def get_derived_quantities(cfg_grid: Dict) -> Dict:
def get_derived_quantities(cfg: Dict) -> Dict:
"""
This function just updates the config with the derived quantities that are only integers or strings.
Expand All @@ -99,6 +99,9 @@ def get_derived_quantities(cfg_grid: Dict) -> Dict:
:param cfg_grid:
:return:
"""

cfg_grid = cfg["grid"]

cfg_grid["dx"] = cfg_grid["xmax"] / cfg_grid["nx"]
cfg_grid["dt"] = 0.05 * cfg_grid["dx"]
cfg_grid["nt"] = int(cfg_grid["tmax"] / cfg_grid["dt"] + 1)
Expand All @@ -110,6 +113,8 @@ def get_derived_quantities(cfg_grid: Dict) -> Dict:
else:
cfg_grid["max_steps"] = cfg_grid["nt"] + 4

cfg["grid"] = cfg_grid

return cfg_grid


Expand Down
2 changes: 1 addition & 1 deletion adept/tf1d/train_damping.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def train_loop():
locs = {"$k_0$": k0, "$a_0$": a0, r"$\nu_{ee}$": nuee}
actual_nk1 = xr.DataArray(fks["n-(k_x)"].loc[locs].data[:, 1], coords=(("t", fks.coords["t"].data),))
with mlflow.start_run(run_name=f"{epoch=}-{sim=}", nested=True) as mlflow_run:
mod_defaults["grid"] = helpers.get_derived_quantities(mod_defaults["grid"])
mod_defaults = helpers.get_derived_quantities(mod_defaults)
misc.log_params(mod_defaults)

mod_defaults["grid"] = helpers.get_solver_quantities(mod_defaults["grid"])
Expand Down
8 changes: 6 additions & 2 deletions adept/vlasov1d/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _initialize_total_distribution_(cfg, cfg_grid):
return n_prof_total, f


def get_derived_quantities(cfg_grid: Dict) -> Dict:
def get_derived_quantities(cfg: Dict) -> Dict:
"""
This function just updates the config with the derived quantities that are only integers or strings.
Expand All @@ -149,6 +149,8 @@ def get_derived_quantities(cfg_grid: Dict) -> Dict:
:param cfg_grid:
:return:
"""
cfg_grid = cfg["grid"]

cfg_grid["dx"] = cfg_grid["xmax"] / cfg_grid["nx"]
cfg_grid["dv"] = 2.0 * cfg_grid["vmax"] / cfg_grid["nv"]

Expand All @@ -162,7 +164,9 @@ def get_derived_quantities(cfg_grid: Dict) -> Dict:
else:
cfg_grid["max_steps"] = cfg_grid["nt"] + 4

return cfg_grid
cfg["grid"] = cfg_grid

return cfg


def get_solver_quantities(cfg: Dict) -> Dict:
Expand Down
6 changes: 5 additions & 1 deletion adept/vlasov2d/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _initialize_total_distribution_(cfg, cfg_grid):
return n_prof_total, f


def get_derived_quantities(cfg_grid: Dict) -> Dict:
def get_derived_quantities(cfg: Dict) -> Dict:
"""
This function just updates the config with the derived quantities that are only integers or strings.
Expand All @@ -153,6 +153,8 @@ def get_derived_quantities(cfg_grid: Dict) -> Dict:
:param cfg_grid:
:return:
"""
cfg_grid = cfg["grid"]

cfg_grid["dx"] = cfg_grid["xmax"] / cfg_grid["nx"]
cfg_grid["dy"] = cfg_grid["ymax"] / cfg_grid["ny"]
cfg_grid["dvx"] = 2.0 * cfg_grid["vmax"] / cfg_grid["nvx"]
Expand All @@ -168,6 +170,8 @@ def get_derived_quantities(cfg_grid: Dict) -> Dict:
else:
cfg_grid["max_steps"] = cfg_grid["nt"] + 4

cfg["grid"] = cfg_grid

return cfg_grid


Expand Down
Loading

0 comments on commit 39231b9

Please sign in to comment.