Skip to content

Commit

Permalink
LPSE Regression and JAX config import update (#36)
Browse files Browse the repository at this point in the history
* Fixed LPSE regression

* updating config import per JAX update
  • Loading branch information
joglekara authored Mar 10, 2024
1 parent a811c41 commit 2947c63
Show file tree
Hide file tree
Showing 13 changed files with 67 additions and 20 deletions.
63 changes: 54 additions & 9 deletions adept/lpse2d/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from plasmapy.formulary.collisions.frequencies import fundamental_electron_collision_freq

from adept.lpse2d.core import integrator, driver
from adept.tf1d.pushers import get_envelope


def write_units(cfg, td):
Expand Down Expand Up @@ -223,15 +224,59 @@ def init_state(cfg: Dict, td=None) -> Dict:
phi = random_amps * np.exp(1j * random_phases)
phi = jnp.fft.fft2(phi)

ureg = pint.UnitRegistry()
_Q = ureg.Quantity
if cfg["density"]["basis"] == "uniform":
nprof = np.ones((cfg["grid"]["nx"], cfg["grid"]["ny"]))

elif cfg["density"]["basis"] == "linear":
left = cfg["density"]["center"] - cfg["density"]["width"] * 0.5
right = cfg["density"]["center"] + cfg["density"]["width"] * 0.5
rise = cfg["density"]["rise"]
mask = get_envelope(rise, rise, left, right, cfg["grid"]["x"])

ureg = pint.UnitRegistry()
_Q = ureg.Quantity

L = (
_Q(cfg["density"]["gradient scale length"]).to("nm").magnitude
/ cfg["units"]["derived"]["x0"].to("nm").magnitude
)
nprof = cfg["density"]["val at center"] + (cfg["grid"]["x"] - cfg["density"]["center"]) / L
nprof = mask * nprof

elif cfg["density"]["basis"] == "exponential":
left = cfg["density"]["center"] - cfg["density"]["width"] * 0.5
right = cfg["density"]["center"] + cfg["density"]["width"] * 0.5
rise = cfg["density"]["rise"]
mask = get_envelope(rise, rise, left, right, cfg["grid"]["x"])

ureg = pint.UnitRegistry()
_Q = ureg.Quantity

L = (
_Q(cfg["density"]["gradient scale length"]).to("nm").magnitude
/ cfg["units"]["derived"]["x0"].to("nm").magnitude
)
nprof = cfg["density"]["val at center"] * np.exp((cfg["grid"]["x"] - cfg["density"]["center"]) / L)
nprof = mask * nprof

elif cfg["density"]["basis"] == "tanh":
left = cfg["density"]["center"] - cfg["density"]["width"] * 0.5
right = cfg["density"]["center"] + cfg["density"]["width"] * 0.5
rise = cfg["density"]["rise"]
nprof = get_envelope(rise, rise, left, right, cfg["grid"]["x"])

if cfg["density"]["bump_or_trough"] == "trough":
nprof = 1 - nprof
nprof = cfg["density"]["baseline"] + cfg["density"]["bump_height"] * nprof

elif cfg["density"]["basis"] == "sine":
baseline = cfg["density"]["baseline"]
amp = cfg["density"]["amplitude"]
kk = cfg["density"]["wavenumber"]
nprof = baseline * (1.0 + amp * jnp.sin(kk * cfg["grid"]["x"]))
else:
raise NotImplementedError

L = (
_Q(cfg["density"]["gradient scale length"]).to("nm").magnitude
/ cfg["units"]["derived"]["x0"].to("nm").magnitude
)
nprof = cfg["density"]["val at center"] + (cfg["grid"]["x"] - cfg["density"]["center"]) / L
nprof = jnp.repeat(nprof[:, None], cfg["grid"]["ny"], axis=1)
state = {
"e0": e0,
"nb": nprof,
Expand Down Expand Up @@ -410,7 +455,7 @@ def post_process(result, cfg: Dict, td: str) -> Tuple[xr.Dataset, xr.Dataset]:
kfields, fields = make_xarrays(cfg, result.ts, result.ys, td)

plot_fields(fields, td)
plot_kt(kfields, td)
# plot_kt(kfields, td)

return kfields, fields

Expand Down
4 changes: 2 additions & 2 deletions adept/lpse2d/train_damping.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from itertools import product

import numpy as np
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)
# config.update("jax_debug_nans", True)
Expand Down Expand Up @@ -41,7 +41,7 @@ def _modify_defaults_(defaults, k0, nuee, a0):

@python_app
def remote_run(run_id, t_or_v):
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)

Expand Down
2 changes: 1 addition & 1 deletion adept/tf1d/train_damping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from itertools import product

import numpy as np
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)
# config.update("jax_debug_nans", True)
Expand Down
1 change: 1 addition & 0 deletions tests/test_lpse2d/configs/epw.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mode: envelope-2d

density:
basis: uniform
offset: 1.0
slope: 0.0
noise:
Expand Down
1 change: 1 addition & 0 deletions tests/test_lpse2d/configs/resonance_search.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mode: envelope-2d

density:
basis: uniform
offset: 1.0
slope: 0.0
noise:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_lpse2d/test_resonance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


import numpy as np
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)
# config.update("jax_debug_nans", True)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_tf1d/test_against_vlasov.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import yaml

import numpy as np
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)
# config.update("jax_disable_jit", True)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_tf1d/test_landau_damping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import yaml

import numpy as np
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)
# config.update("jax_disable_jit", True)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_tf1d/test_resonance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import yaml, pytest

import numpy as np
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)
# config.update("jax_disable_jit", True)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_tf1d/test_resonance_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


import numpy as np
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)
# config.update("jax_debug_nans", True)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_vlasov1d/test_absorbing_wave.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from jax.config import config
from jax import config
import equinox as eqx

config.update("jax_enable_x64", True)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_vlasov1d/test_landau_damping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import yaml, pytest

import numpy as np
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)
# config.update("jax_disable_jit", True)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_vlasov2d/test_landau_damping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import yaml, pytest

import numpy as np
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)
# config.update("jax_disable_jit", True)
Expand Down

0 comments on commit 2947c63

Please sign in to comment.