diff --git a/adept/lpse2d/helpers.py b/adept/lpse2d/helpers.py index 64910c0..ca5cf8e 100644 --- a/adept/lpse2d/helpers.py +++ b/adept/lpse2d/helpers.py @@ -14,6 +14,7 @@ from plasmapy.formulary.collisions.frequencies import fundamental_electron_collision_freq from adept.lpse2d.core import integrator, driver +from adept.tf1d.pushers import get_envelope def write_units(cfg, td): @@ -223,15 +224,59 @@ def init_state(cfg: Dict, td=None) -> Dict: phi = random_amps * np.exp(1j * random_phases) phi = jnp.fft.fft2(phi) - ureg = pint.UnitRegistry() - _Q = ureg.Quantity + if cfg["density"]["basis"] == "uniform": + nprof = np.ones((cfg["grid"]["nx"], cfg["grid"]["ny"])) + + elif cfg["density"]["basis"] == "linear": + left = cfg["density"]["center"] - cfg["density"]["width"] * 0.5 + right = cfg["density"]["center"] + cfg["density"]["width"] * 0.5 + rise = cfg["density"]["rise"] + mask = get_envelope(rise, rise, left, right, cfg["grid"]["x"]) + + ureg = pint.UnitRegistry() + _Q = ureg.Quantity + + L = ( + _Q(cfg["density"]["gradient scale length"]).to("nm").magnitude + / cfg["units"]["derived"]["x0"].to("nm").magnitude + ) + nprof = cfg["density"]["val at center"] + (cfg["grid"]["x"] - cfg["density"]["center"]) / L + nprof = mask * nprof + + elif cfg["density"]["basis"] == "exponential": + left = cfg["density"]["center"] - cfg["density"]["width"] * 0.5 + right = cfg["density"]["center"] + cfg["density"]["width"] * 0.5 + rise = cfg["density"]["rise"] + mask = get_envelope(rise, rise, left, right, cfg["grid"]["x"]) + + ureg = pint.UnitRegistry() + _Q = ureg.Quantity + + L = ( + _Q(cfg["density"]["gradient scale length"]).to("nm").magnitude + / cfg["units"]["derived"]["x0"].to("nm").magnitude + ) + nprof = cfg["density"]["val at center"] * np.exp((cfg["grid"]["x"] - cfg["density"]["center"]) / L) + nprof = mask * nprof + + elif cfg["density"]["basis"] == "tanh": + left = cfg["density"]["center"] - cfg["density"]["width"] * 0.5 + right = cfg["density"]["center"] + cfg["density"]["width"] * 0.5 + rise = cfg["density"]["rise"] + nprof = get_envelope(rise, rise, left, right, cfg["grid"]["x"]) + + if cfg["density"]["bump_or_trough"] == "trough": + nprof = 1 - nprof + nprof = cfg["density"]["baseline"] + cfg["density"]["bump_height"] * nprof + + elif cfg["density"]["basis"] == "sine": + baseline = cfg["density"]["baseline"] + amp = cfg["density"]["amplitude"] + kk = cfg["density"]["wavenumber"] + nprof = baseline * (1.0 + amp * jnp.sin(kk * cfg["grid"]["x"])) + else: + raise NotImplementedError - L = ( - _Q(cfg["density"]["gradient scale length"]).to("nm").magnitude - / cfg["units"]["derived"]["x0"].to("nm").magnitude - ) - nprof = cfg["density"]["val at center"] + (cfg["grid"]["x"] - cfg["density"]["center"]) / L - nprof = jnp.repeat(nprof[:, None], cfg["grid"]["ny"], axis=1) state = { "e0": e0, "nb": nprof, @@ -410,7 +455,7 @@ def post_process(result, cfg: Dict, td: str) -> Tuple[xr.Dataset, xr.Dataset]: kfields, fields = make_xarrays(cfg, result.ts, result.ys, td) plot_fields(fields, td) - plot_kt(kfields, td) + # plot_kt(kfields, td) return kfields, fields diff --git a/adept/lpse2d/train_damping.py b/adept/lpse2d/train_damping.py index 8d7cc53..e988ccb 100644 --- a/adept/lpse2d/train_damping.py +++ b/adept/lpse2d/train_damping.py @@ -4,7 +4,7 @@ from itertools import product import numpy as np -from jax.config import config +from jax import config config.update("jax_enable_x64", True) # config.update("jax_debug_nans", True) @@ -41,7 +41,7 @@ def _modify_defaults_(defaults, k0, nuee, a0): @python_app def remote_run(run_id, t_or_v): - from jax.config import config + from jax import config config.update("jax_enable_x64", True) diff --git a/adept/tf1d/train_damping.py b/adept/tf1d/train_damping.py index 933fe46..bbd45f3 100644 --- a/adept/tf1d/train_damping.py +++ b/adept/tf1d/train_damping.py @@ -5,7 +5,7 @@ from itertools import product import numpy as np -from jax.config import config +from jax import config config.update("jax_enable_x64", True) # config.update("jax_debug_nans", True) diff --git a/tests/test_lpse2d/configs/epw.yaml b/tests/test_lpse2d/configs/epw.yaml index 53d5255..02332e2 100644 --- a/tests/test_lpse2d/configs/epw.yaml +++ b/tests/test_lpse2d/configs/epw.yaml @@ -1,6 +1,7 @@ mode: envelope-2d density: + basis: uniform offset: 1.0 slope: 0.0 noise: diff --git a/tests/test_lpse2d/configs/resonance_search.yaml b/tests/test_lpse2d/configs/resonance_search.yaml index 7ac4487..f63e018 100644 --- a/tests/test_lpse2d/configs/resonance_search.yaml +++ b/tests/test_lpse2d/configs/resonance_search.yaml @@ -1,6 +1,7 @@ mode: envelope-2d density: + basis: uniform offset: 1.0 slope: 0.0 noise: diff --git a/tests/test_lpse2d/test_resonance.py b/tests/test_lpse2d/test_resonance.py index 4c70062..4980bd0 100644 --- a/tests/test_lpse2d/test_resonance.py +++ b/tests/test_lpse2d/test_resonance.py @@ -5,7 +5,7 @@ import numpy as np -from jax.config import config +from jax import config config.update("jax_enable_x64", True) # config.update("jax_debug_nans", True) diff --git a/tests/test_tf1d/test_against_vlasov.py b/tests/test_tf1d/test_against_vlasov.py index 7afe45d..1e73215 100644 --- a/tests/test_tf1d/test_against_vlasov.py +++ b/tests/test_tf1d/test_against_vlasov.py @@ -3,7 +3,7 @@ import yaml import numpy as np -from jax.config import config +from jax import config config.update("jax_enable_x64", True) # config.update("jax_disable_jit", True) diff --git a/tests/test_tf1d/test_landau_damping.py b/tests/test_tf1d/test_landau_damping.py index ca874ef..d3e5e35 100644 --- a/tests/test_tf1d/test_landau_damping.py +++ b/tests/test_tf1d/test_landau_damping.py @@ -3,7 +3,7 @@ import yaml import numpy as np -from jax.config import config +from jax import config config.update("jax_enable_x64", True) # config.update("jax_disable_jit", True) diff --git a/tests/test_tf1d/test_resonance.py b/tests/test_tf1d/test_resonance.py index 5717b70..58f8dcf 100644 --- a/tests/test_tf1d/test_resonance.py +++ b/tests/test_tf1d/test_resonance.py @@ -3,7 +3,7 @@ import yaml, pytest import numpy as np -from jax.config import config +from jax import config config.update("jax_enable_x64", True) # config.update("jax_disable_jit", True) diff --git a/tests/test_tf1d/test_resonance_search.py b/tests/test_tf1d/test_resonance_search.py index 208765d..6c1bc43 100644 --- a/tests/test_tf1d/test_resonance_search.py +++ b/tests/test_tf1d/test_resonance_search.py @@ -5,7 +5,7 @@ import numpy as np -from jax.config import config +from jax import config config.update("jax_enable_x64", True) # config.update("jax_debug_nans", True) diff --git a/tests/test_vlasov1d/test_absorbing_wave.py b/tests/test_vlasov1d/test_absorbing_wave.py index 7d901c8..0587ee7 100644 --- a/tests/test_vlasov1d/test_absorbing_wave.py +++ b/tests/test_vlasov1d/test_absorbing_wave.py @@ -1,5 +1,5 @@ import numpy as np -from jax.config import config +from jax import config import equinox as eqx config.update("jax_enable_x64", True) diff --git a/tests/test_vlasov1d/test_landau_damping.py b/tests/test_vlasov1d/test_landau_damping.py index caefa06..edec95d 100644 --- a/tests/test_vlasov1d/test_landau_damping.py +++ b/tests/test_vlasov1d/test_landau_damping.py @@ -5,7 +5,7 @@ import yaml, pytest import numpy as np -from jax.config import config +from jax import config config.update("jax_enable_x64", True) # config.update("jax_disable_jit", True) diff --git a/tests/test_vlasov2d/test_landau_damping.py b/tests/test_vlasov2d/test_landau_damping.py index d1fa076..876bc60 100644 --- a/tests/test_vlasov2d/test_landau_damping.py +++ b/tests/test_vlasov2d/test_landau_damping.py @@ -5,7 +5,7 @@ import yaml, pytest import numpy as np -from jax.config import config +from jax import config config.update("jax_enable_x64", True) # config.update("jax_disable_jit", True)