Skip to content

Commit

Permalink
VFP hotspot example (#57)
Browse files Browse the repository at this point in the history
* VFP hotspot example and units in input

* updating input decks

* removing unused code
  • Loading branch information
joglekara authored Jul 26, 2024
1 parent 92cb191 commit ddb2606
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 179 deletions.
1 change: 0 additions & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ jobs:
python -m pip install --upgrade pip
python -m pip install --upgrade black
black --line-length 120 --check adept
black --line-length 120 --check utils
python -m pip install --upgrade pytest wheel
python -m pip install --upgrade -r requirements-cpu.txt
Expand Down
31 changes: 15 additions & 16 deletions adept/vfp1d/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
from astropy import constants as csts, units as u
from astropy.units import Quantity as _Q
from diffrax import diffeqsolve, SaveAt, ODETerm, SubSaveAt
from jax import numpy as jnp, tree_util as jtu

Expand Down Expand Up @@ -36,13 +37,6 @@ def write_units(self) -> Dict:

beta = vth / csts.c

box_length = ((self.cfg["grid"]["xmax"] - self.cfg["grid"]["xmin"]) * x0).to("micron")
if "ymax" in self.cfg["grid"].keys():
box_width = ((self.cfg["grid"]["ymax"] - self.cfg["grid"]["ymin"]) * x0).to("micron")
else:
box_width = "inf"
sim_duration = (self.cfg["grid"]["tmax"] * tp0).to("ps")

logLambda_ei, logLambda_ee = calc_logLambda(self.cfg, ne, Te, Z, ion_species)
logLambda_ee = logLambda_ei

Expand Down Expand Up @@ -88,15 +82,15 @@ def write_units(self) -> Dict:
"lambda_mfp_epphaines": (vth / nuei_epphaines).to("micron"),
"nD_NRL": nD_NRL,
"nD_Shkarofsky": nD_Shkarofsky,
"box_length": box_length,
"box_width": box_width,
"sim_duration": sim_duration,
# "box_length": box_length,
# "box_width": box_width,
# "sim_duration": sim_duration,
}

self.cfg["units"]["derived"] = all_quantities
self.cfg["grid"]["beta"] = beta.value

return all_quantities
return {k: str(v) for k, v in all_quantities.items()}

def get_derived_quantities(self):
"""
Expand All @@ -108,20 +102,21 @@ def get_derived_quantities(self):
:return:
"""
cfg_grid = self.cfg["grid"]
# cfg_grid["xmax"] = u.Quantity(cfg_grid["xmax"]

cfg_grid["xmax"] = (_Q(cfg_grid["xmax"]) / _Q(self.cfg["units"]["derived"]["x0"])).to("").value
cfg_grid["xmin"] = (_Q(cfg_grid["xmin"]) / _Q(self.cfg["units"]["derived"]["x0"])).to("").value
cfg_grid["dx"] = cfg_grid["xmax"] / cfg_grid["nx"]

# sqrt(2 * k * T / m)
cfg_grid["vmax"] = (
8
* np.sqrt(
(u.Quantity(self.cfg["units"]["reference electron temperature"]) / (csts.m_e * csts.c**2.0)).to("")
).value
* np.sqrt((_Q(self.cfg["units"]["reference electron temperature"]) / (csts.m_e * csts.c**2.0)).to("")).value
)

cfg_grid["dv"] = cfg_grid["vmax"] / cfg_grid["nv"]

cfg_grid["tmax"] = (_Q(cfg_grid["tmax"]) / self.cfg["units"]["derived"]["tp0"]).to("").value
cfg_grid["dt"] = (_Q(cfg_grid["dt"]) / self.cfg["units"]["derived"]["tp0"]).to("").value

cfg_grid["nt"] = int(cfg_grid["tmax"] / cfg_grid["dt"]) + 1

if cfg_grid["nt"] > 1e6:
Expand All @@ -131,6 +126,10 @@ def get_derived_quantities(self):
cfg_grid["max_steps"] = cfg_grid["nt"] + 4

cfg_grid["tmax"] = cfg_grid["dt"] * cfg_grid["nt"]

print("tmax", cfg_grid["tmax"], "dt", cfg_grid["dt"])
print("xmax", cfg_grid["xmax"], "dx", cfg_grid["dx"])

self.cfg["grid"] = cfg_grid

def get_solver_quantities(self):
Expand Down
108 changes: 11 additions & 97 deletions adept/vfp1d/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

import numpy as np
from jax import Array
import xarray, yaml, plasmapy
import xarray, yaml
from astropy import units as u, constants as csts
from astropy.units import Quantity as _Q
from jax import numpy as jnp
from adept import get_envelope

Expand Down Expand Up @@ -38,93 +39,6 @@ def gamma_5_over_m(m: float) -> Array:
return np.interp(m, m_ax, g_5_m)


def write_units(cfg: Dict, td: str) -> Dict:
"""
This function writes the units to a file and updates the config with the derived quantities
It is a REQUIRED function for the exoskeleton
:param cfg: Dict
:param td: str
:return: Dict
"""

ne = u.Quantity(cfg["units"]["reference electron density"]).to("1/cm^3")
ni = ne / cfg["units"]["Z"]
Te = u.Quantity(cfg["units"]["reference electron temperature"]).to("eV")
Ti = u.Quantity(cfg["units"]["reference ion temperature"]).to("eV")
Z = cfg["units"]["Z"]
n0 = u.Quantity("9.0663e21/cm^3")
ion_species = cfg["units"]["Ion"]

wp0 = np.sqrt(n0 * csts.e.to("C") ** 2.0 / (csts.m_e * csts.eps0)).to("Hz")
tp0 = (1 / wp0).to("fs")

vth = np.sqrt(2 * Te / csts.m_e).to("m/s") # mean square velocity eq 4-51a in Shkarofsky

x0 = (csts.c / wp0).to("nm")

beta = vth / csts.c

box_length = ((cfg["grid"]["xmax"] - cfg["grid"]["xmin"]) * x0).to("micron")
if "ymax" in cfg["grid"].keys():
box_width = ((cfg["grid"]["ymax"] - cfg["grid"]["ymin"]) * x0).to("micron")
else:
box_width = "inf"
sim_duration = (cfg["grid"]["tmax"] * tp0).to("ps")

logLambda_ei, logLambda_ee = calc_logLambda(cfg, ne, Te, Z, ion_species)
logLambda_ee = logLambda_ei

nD_NRL = 1.72e9 * Te.value**1.5 / np.sqrt(ne.value)
nD_Shkarofsky = np.exp(logLambda_ei) * Z / 9

nuei_shk = np.sqrt(2.0 / np.pi) * wp0 * logLambda_ei / np.exp(logLambda_ei)
nuei_nrl = np.sqrt(2.0 / np.pi) * wp0 * logLambda_ei / nD_NRL

lambda_mfp_shk = (vth / nuei_shk).to("micron")
lambda_mfp_nrl = (vth / nuei_nrl).to("micron")

nuei_epphaines = (
1 / (0.75 * np.sqrt(csts.m_e) * Te**1.5 / (np.sqrt(2 * np.pi) * ni * Z**2.0 * csts.e.gauss**4.0 * logLambda_ei))
).to("Hz")

all_quantities = {
"wp0": wp0,
"n0": n0,
"tp0": tp0,
"ne": ne,
"vth": vth,
"Te": Te,
"Ti": Ti,
"logLambda_ei": logLambda_ei,
"logLambda_ee": logLambda_ee,
"beta": beta,
"x0": x0,
"nuei_shk": nuei_shk,
"nuei_nrl": nuei_nrl,
"nuei_epphaines": nuei_epphaines,
"nuei_shk_norm": nuei_shk / wp0,
"nuei_nrl_norm": nuei_nrl / wp0,
"nuei_epphaines_norm": nuei_epphaines / wp0,
"lambda_mfp_shk": lambda_mfp_shk,
"lambda_mfp_nrl": lambda_mfp_nrl,
"lambda_mfp_epphaines": (vth / nuei_epphaines).to("micron"),
"nD_NRL": nD_NRL,
"nD_Shkarofsky": nD_Shkarofsky,
"box_length": box_length,
"box_width": box_width,
"sim_duration": sim_duration,
}

cfg["units"]["derived"] = all_quantities
cfg["grid"]["beta"] = beta.value

with open(os.path.join(td, "units.yaml"), "w") as fi:
yaml.dump({k: str(v) for k, v in all_quantities.items()}, fi)

return cfg


def calc_logLambda(cfg: Dict, ne: float, Te: float, Z: int, ion_species: str) -> Tuple[float, float]:
"""
Calculate the Coulomb logarithm
Expand All @@ -139,11 +53,7 @@ def calc_logLambda(cfg: Dict, ne: float, Te: float, Z: int, ion_species: str) ->
"""
if isinstance(cfg["units"]["logLambda"], str):
if cfg["units"]["logLambda"].casefold() == "plasmapy":
logLambda_ei = plasmapy.formulary.Coulomb_logarithm(n_e=ne, T=Te, z_mean=Z, species=("e", ion_species))
logLambda_ee = plasmapy.formulary.Coulomb_logarithm(n_e=ne, T=Te, z_mean=1.0, species=("e", "e"))

elif cfg["units"]["logLambda"].casefold() == "nrl":
if cfg["units"]["logLambda"].casefold() == "nrl":
log_ne = np.log(ne.to("1/cm^3").value)
log_Te = np.log(Te.to("eV").value)
log_Z = np.log(Z)
Expand Down Expand Up @@ -257,9 +167,13 @@ def _initialize_total_distribution_(cfg, cfg_grid):
profs[k] = np.ones_like(prof_total[k])

elif species_params[k]["basis"] == "tanh":
left = species_params[k]["center"] - species_params[k]["width"] * 0.5
right = species_params[k]["center"] + species_params[k]["width"] * 0.5
rise = species_params[k]["rise"]
center = (_Q(species_params[k]["center"]) / cfg["units"]["derived"]["x0"]).to("").value
width = (_Q(species_params[k]["width"]) / cfg["units"]["derived"]["x0"]).to("").value
rise = (_Q(species_params[k]["rise"]) / cfg["units"]["derived"]["x0"]).to("").value

left = center - width * 0.5
right = center + width * 0.5
# rise = species_params[k]["rise"]
prof = get_envelope(rise, rise, left, right, cfg_grid["x"])

if species_params[k]["bump_or_trough"] == "trough":
Expand All @@ -269,7 +183,7 @@ def _initialize_total_distribution_(cfg, cfg_grid):
elif species_params[k]["basis"] == "sine":
baseline = species_params[k]["baseline"]
amp = species_params[k]["amplitude"]
ll = species_params[k]["wavelength"]
ll = (_Q(species_params[k]["wavelength"]) / cfg["units"]["derived"]["x0"]).to("").value

profs[k] = baseline * (1.0 + amp * jnp.sin(2 * jnp.pi / ll * cfg_grid["x"]))

Expand Down
20 changes: 6 additions & 14 deletions adept/vfp1d/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jax import numpy as jnp
import numpy as np
import xarray as xr
from time import time
from astropy.units import Quantity as _Q


def calc_EH(this_Z: int, this_wt: float) -> float:
Expand Down Expand Up @@ -186,7 +186,7 @@ def store_f(cfg: Dict, this_t: Dict, td: str, ys: Dict) -> xr.Dataset:
{
dist: xr.DataArray(
ys["electron"][dist],
coords=(("t (ps)", this_t["electron"]), ("x (um)", xax), ("v (c)", cfg["grid"]["v"])),
coords=(("t (ps)", tax), ("x (um)", xax), ("v (c)", cfg["grid"]["v"])),
)
for dist in ys["electron"].keys()
}
Expand Down Expand Up @@ -342,20 +342,12 @@ def get_save_quantities(cfg: Dict) -> Dict:
:param cfg:
:return: The updated config
"""

for k in cfg["save"].keys(): # this can be fields or electron or scalar?
for k2 in cfg["save"][k].keys(): # this can be t, x, y, kx, ky (eventually)
if k2 == "x":
dx = (cfg["save"][k][k2][f"{k2}max"] - cfg["save"][k][k2][f"{k2}min"]) / cfg["save"][k][k2][f"n{k2}"]
cfg["save"][k][k2]["ax"] = np.linspace(
cfg["save"][k][k2][f"{k2}min"] + dx / 2.0,
cfg["save"][k][k2][f"{k2}max"] - dx / 2.0,
cfg["save"][k][k2][f"n{k2}"],
)

else:
cfg["save"][k][k2]["ax"] = np.linspace(
cfg["save"][k][k2][f"{k2}min"], cfg["save"][k][k2][f"{k2}max"], cfg["save"][k][k2][f"n{k2}"]
)
tmin = (_Q(cfg["save"][k]["t"]["tmin"]) / cfg["units"]["derived"]["tp0"]).to("").value
tmax = (_Q(cfg["save"][k]["t"]["tmax"]) / cfg["units"]["derived"]["tp0"]).to("").value
cfg["save"][k]["t"]["ax"] = jnp.linspace(tmin, tmax, cfg["save"][k]["t"]["nt"])

if k.startswith("fields"):
cfg["save"][k]["func"] = get_field_save_func(cfg, k)
Expand Down
39 changes: 20 additions & 19 deletions configs/vfp-1d/epp-short.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
units:
laser wavelength: 351nm
reference electron temperature: 1000eV
reference ion temperature: 500eV
reference electron density: 5e20/cm^3
reference electron temperature: 300eV
reference ion temperature: 300eV
reference electron density: 1.5e21/cm^3
Z: 30
Ion: Au+
logLambda: nrl
Expand All @@ -21,42 +21,43 @@ density:
basis: uniform
baseline: 1.0
bump_or_trough: bump
center: 0.0
rise: 25.0
center: 0m
rise: 1m
bump_height: 0.0
width: 100000.0
width: 1m
T:
basis: sine
baseline: 1.0
amplitude: 1.0e-3
wavelength: 50000.0
wavelength: 250um

grid:
dt: 100.0
nv: 512
dt: 1fs
nv: 256
nx: 32
tmin: 0.
tmax: 50000.0
xmax: 100000.0
xmin: 0.0
tmin: 0.ps
tmax: 2ps
vmax: 8.0
xmax: 500um
xmin: 0.0um
nl: 1

save:
fields:
t:
tmin: 0.0
tmax: 50000.0
tmin: 0.0ps
tmax: 0.5ps
nt: 11
electron:
t:
tmin: 0.0
tmax: 50000.0
tmin: 0.0ps
tmax: 0.5ps
nt: 6

solver: vfp-2d
solver: vfp-1d

mlflow:
experiment: vfp2d
experiment: vfp1d
run: epperlein-short

drivers:
Expand Down
Loading

0 comments on commit ddb2606

Please sign in to comment.