Skip to content

Commit

Permalink
Two plasmon decay using the LPSE equations (#41)
Browse files Browse the repository at this point in the history
* divergence fix

* density gradients

* noise init

* added kx ky diag

* refactor

* new units and solver based on matlab code

* units

* working TPD

* tpd sweep

* added bandwidth

* learn code

* Working optimization

* parent runs

* trainable NN

* cleanup

* refactor

* New VAE and args get postprocessed

* scale length

* save func

* threshold

* running with interpolated save func

* xarrays

* generalized for opt and learn and fwd

fwd lpse2d works

* xmax and save xmax auto

* opt and sweep updates

* working opt

* working optimization

* update

* intensity bug

* e field damping at boundaries

* sum amplitudes

* zero mask and low pass filter

* boundaries

* working collisions

* passing vlasov1d and twofluid1d tests

* passing tests

* passing tests
  • Loading branch information
joglekara authored Jun 26, 2024
1 parent d6399b7 commit df1510a
Show file tree
Hide file tree
Showing 80 changed files with 2,365 additions and 1,795 deletions.
329 changes: 275 additions & 54 deletions adept/lpse2d/core/epw.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,17 @@
from typing import Dict, Tuple
from functools import partial

import diffrax
import jax
from jax import numpy as jnp
import equinox as eqx
import numpy as np
from theory import electrostatic
from adept.theory import electrostatic
from adept.lpse2d.core.driver import Driver
from adept.lpse2d.core.trapper import ParticleTrapper


class ParticleTrapper(eqx.Module):
kx: np.ndarray
kax_sq: jax.Array

model_kld: float
wis: jax.Array
norm_kld: jnp.float64
norm_nuee: jnp.float64
vph: jnp.float64
fft_norm: float
dx: float

def __init__(self, cfg, species="electron"):
self.kx = cfg["grid"]["kx"]
self.dx = cfg["grid"]["dx"]
self.kax_sq = cfg["grid"]["kx"][:, None] ** 2 + cfg["grid"]["ky"][None, :] ** 2
table_wrs, table_wis, table_klds = electrostatic.get_complex_frequency_table(
1024, cfg["terms"]["epw"]["kinetic real part"]
)
all_ks = jnp.sqrt(self.kax_sq).flatten()
self.model_kld = cfg["terms"]["epw"]["trapping"]["kld"]
self.wis = jnp.interp(all_ks, table_klds, table_wis, left=0.0, right=0.0).reshape(self.kax_sq.shape)
self.norm_kld = (self.model_kld - 0.26) / 0.14
self.norm_nuee = (jnp.log10(1.0e-7) + 7.0) / -4.0

this_wr = jnp.interp(self.model_kld, table_klds, table_wrs, left=1.0, right=table_wrs[-1])
self.vph = this_wr / self.model_kld
self.fft_norm = cfg["grid"]["nx"] * cfg["grid"]["ny"] / 4.0
# Make models
# if models is not None:
# self.nu_g_model = models["nu_g"]
# else:
# self.nu_g_model = lambda x: -32

def __call__(self, t, delta, args):
e = args["eh"]
ek = jnp.fft.fft2(e[..., 0]) / self.fft_norm

# this is where a specific k is chosen for the growth rate and where the identity of this delta object is given
chosen_ek = jnp.interp(self.model_kld, self.kx, jnp.mean(jnp.abs(ek), axis=1))
norm_e = (jnp.log10(chosen_ek + 1e-10) + 10.0) / -10.0
func_inputs = jnp.stack([norm_e, self.norm_kld, self.norm_nuee], axis=-1)
growth_rates = 10 ** jnp.squeeze(3 * args["nu_g"](func_inputs))

return -self.vph * jnp.gradient(jnp.pad(delta, pad_width=1, mode="wrap"), axis=0)[
1:-1, 1:-1
] / self.dx + growth_rates * jnp.abs(jnp.fft.ifft2(ek * self.fft_norm * self.wis)) / (1.0 + delta**2.0)


class EPW2D(eqx.Module):
class SpectralPotential_old(eqx.Module):
wp0: float
ld_rates: jax.Array
n0: float
Expand Down Expand Up @@ -138,9 +91,9 @@ def calc_density_pert_step(self, nb: jax.Array, dn: jax.Array) -> jax.Array:
return -1j / 2.0 * self.wp0 * (1.0 - nb / self.n0 - dn / self.n0)

def _calc_div_(self, arr):
arrk = jnp.fft.fft2(arr)
arrk = jnp.fft.fft2(arr, axes=[0, 1])
divk = self.kx[:, None] * arrk[..., 0] + self.ky[None, :] * arrk[..., 1]
return jnp.fft.ifft2(divk)
return jnp.fft.ifft2(divk, axes=[0, 1])

def calc_tpd_source_step(self, phi: jax.Array, e0: jax.Array, nb: jax.Array, t: float) -> jax.Array:
"""
Expand Down Expand Up @@ -197,7 +150,7 @@ def update_delta(self, t, y, args):
# return y["delta"] + self.dt * self.trapper(t, y["delta"], args)
return diffrax.diffeqsolve(
terms=diffrax.ODETerm(self.trapper),
solver=diffrax.Dopri8(),
solver=diffrax.Tsit5(),
t0=t,
t1=t + self.dt,
max_steps=self.num_substeps,
Expand Down Expand Up @@ -233,3 +186,271 @@ def __call__(self, t, y, args):
)

return y


class SpectralPotential:
def __init__(self, cfg) -> None:

self.kx = cfg["grid"]["kx"]
self.ky = cfg["grid"]["ky"]
self.k_sq = self.kx[:, None] ** 2 + self.ky[None, :] ** 2
self.wp0 = cfg["units"]["derived"]["wp0"]
self.e = cfg["units"]["derived"]["e"]
self.me = cfg["units"]["derived"]["me"]
self.w0 = cfg["units"]["derived"]["w0"]
self.envelope_density = cfg["units"]["envelope density"]
self.one_over_ksq = cfg["grid"]["one_over_ksq"]
self.boundary_envelope = cfg["grid"]["absorbing_boundaries"]
self.dt = cfg["grid"]["dt"]
self.cfg = cfg
self.amp_key, self.phase_key = jax.random.split(jax.random.PRNGKey(np.random.randint(2**20)), 2)
self.low_pass_filter = np.where(np.sqrt(self.k_sq) < 2.0 / 3.0 * np.amax(self.kx), 1, 0)
zero_mask = cfg["grid"]["zero_mask"]
self.low_pass_filter = self.low_pass_filter * zero_mask
self.nx = cfg["grid"]["nx"]
self.ny = cfg["grid"]["ny"]
# self.step_tpd = partial(
# diffrax.diffeqsolve,
# terms=diffrax.ODETerm(self.tpd),
# solver=diffrax.Tsit5(),
# t0=0.0,
# t1=self.dt,
# dt0=self.dt,
# )
self.tpd_const = 1j * self.e / (8 * self.wp0 * self.me)

def calc_fields_from_phi(self, phi):
"""
Calculates ex(x, y) and ey(x, y) from phi.
checked
Args:
phi (jnp.array): phi(x, y)
Returns:
ex_ey (jnp.array): Vector field e(x, y, dir)
"""
phi_k = jnp.fft.fft2(phi)
phi_k *= self.low_pass_filter

ex_k = self.kx[:, None] * phi_k * self.low_pass_filter
ey_k = self.ky[None, :] * phi_k * self.low_pass_filter
return -1j * jnp.fft.ifft2(ex_k), -1j * jnp.fft.ifft2(ey_k)

def calc_phi_from_fields(self, ex, ey):
"""
checked
"""
ex_k = jnp.fft.fft2(ex)
ey_k = jnp.fft.fft2(ey)
divE_k = 1j * (self.kx[:, None] * ex_k + self.ky[None, :] * ey_k)

phi_k = divE_k * self.one_over_ksq
phi = jnp.fft.ifft2(phi_k * self.low_pass_filter)

return phi

def tpd(self, t, y, args):
"""
checked
"""
E0 = args["E0"] # .view(jnp.complex128)
phi = y # .view(jnp.complex128)
_, ey = self.calc_fields_from_phi(phi)

tpd1 = E0[..., 1] * jnp.conj(ey)
tpd1 = jnp.fft.ifft2(jnp.fft.fft2(tpd1) * self.low_pass_filter)
# tpd1 = E0_Ey

divE_true = jnp.fft.ifft2(self.k_sq * jnp.fft.fft2(phi))
E0_divE_k = jnp.fft.fft2(E0[..., 1] * jnp.conj(divE_true))
tpd2 = 1j * self.ky[None, :] * self.one_over_ksq * E0_divE_k
tpd2 = jnp.fft.ifft2(tpd2 * self.low_pass_filter)

total_tpd = self.tpd_const * jnp.exp(-1j * (self.w0 - 2 * self.wp0) * t) * (tpd1 + tpd2)

dphi = total_tpd # .view(jnp.float64)

return dphi

def calc_tpd1(self, t, y, args):
E0 = args["E0"]
phi = y

_, ey = self.calc_fields_from_phi(phi)

tpd1 = E0[..., 1] * jnp.conj(ey)
return self.tpd_const * tpd1

def calc_tpd2(self, t, y, args):
phi = y
E0 = args["E0"]

divE_true = jnp.fft.ifft2(self.k_sq * jnp.fft.fft2(phi))
E0_divE_k = jnp.fft.fft2(E0[..., 1] * jnp.conj(divE_true))

tpd2 = 1j * self.ky[None, :] * self.one_over_ksq * E0_divE_k
tpd2 = jnp.fft.ifft2(tpd2)
return self.tpd_const * tpd2

def get_noise(self):
random_amps = 1000.0 # jax.random.uniform(self.amp_key, (self.nx, self.ny))
random_phases = 2 * np.pi * jax.random.uniform(self.phase_key, (self.nx, self.ny))
return jnp.fft.ifft2(random_amps * jnp.exp(1j * random_phases) * self.low_pass_filter)

def __call__(self, t, y, args):
phi = y["epw"]
E0 = y["E0"]
background_density = y["background_density"]
vte_sq = y["vte_sq"]

if self.cfg["terms"]["epw"]["linear"]:
# linear propagation
phi = jnp.fft.ifft2(jnp.fft.fft2(phi) * jnp.exp(-1j * 1.5 * vte_sq[0, 0] / self.wp0 * self.k_sq * self.dt))

# tpd
if self.cfg["terms"]["epw"]["source"]["tpd"]:
phi = phi + self.dt * self.tpd(t, phi, args={"E0": E0})

# density gradient
if self.cfg["terms"]["epw"]["density_gradient"]:
ex, ey = self.calc_fields_from_phi(phi)
ex *= jnp.exp(-1j * self.wp0 / 2.0 * (1 - background_density / self.envelope_density) * self.dt)
ey *= jnp.exp(-1j * self.wp0 / 2.0 * (1 - background_density / self.envelope_density) * self.dt)
phi = self.calc_phi_from_fields(ex, ey)

if self.cfg["terms"]["epw"]["source"]["noise"]:
phi += self.dt * self.get_noise()

return phi


class FDChargeDensity:
def __init__(self, cfg):
self.cfg = cfg
self.wp0 = cfg["units"]["derived"]["wp0"]
self.envelope_density = cfg["units"]["envelope density"]
self.kx = cfg["grid"]["kx"]
self.ky = cfg["grid"]["ky"]
self.dx = cfg["grid"]["dx_norm"]
self.dy = cfg["grid"]["dy_norm"]
self.e = cfg["units"]["derived"]["e"]
self.me = cfg["units"]["derived"]["me"]
self.w0 = cfg["units"]["derived"]["w0"]
self.dt = cfg["grid"]["dt"]

def calc_fields_from_divE(self, divE):
"""
Calculates ex(x, y) and ey(x, y) from divE.
Args:
divE (jnp.array): divE(x, y)
Returns:
ex_ey (jnp.array): Vector field e(x, y, dir)
"""
phi = jnp.fft.fft2(divE)

ex_k = self.cfg["grid"]["kx"][:, None] * phi
ey_k = self.cfg["grid"]["ky"][None, :] * phi
return -1j * jnp.concatenate([jnp.fft.ifft2(ex_k)[..., None], jnp.fft.ifft2(ey_k)[..., None]], axis=-1)

def calc_divE_from_fields(self, ex_ey):
ex_k = jnp.fft.fft2(ex_ey[..., 0])
ey_k = jnp.fft.fft2(ex_ey[..., 1])
divE_k = 1j * (self.kx[:, None] * ex_k + self.ky[None, :] * ey_k)
return jnp.fft.ifft2(divE_k)

def linear_propagate(self, divE, vte_sq, background_density):
padded_divE = jnp.pad(divE, ((1, 1), (1, 1)), mode="wrap")
d_divE = (
1.5
* 1j
* vte_sq[0, 0]
/ self.wp0
* (
(padded_divE[2:, 1:-1] - 2 * padded_divE[1:-1, 1:-1] + padded_divE[:-2, 1:-1]) / self.dx**2
+ (padded_divE[1:-1, 2:] - 2 * padded_divE[1:-1, 1:-1] + padded_divE[1:-1, :-2]) / self.dy**2
)
)
d_divE -= 0.5 * 1j * self.wp0 * (background_density / self.envelope_density - 1) * divE
return d_divE

def tpd(self, t, E0, divE, ey):
E0_Ey = E0[..., 1] * jnp.conj(ey)
E0_rho = E0[..., 1] * jnp.conj(divE)
padded_E0_rho = jnp.pad(E0_rho, ((1, 1), (1, 1)), mode="wrap")
padded_E0_Ey = jnp.pad(E0_Ey, ((1, 1), (1, 1)), mode="wrap")
tpd = (
-1j
* self.e
/ (8 * self.wp0 * self.me)
* jnp.exp(-1j * (self.w0 - 2 * self.wp0) * t)
* (
(padded_E0_Ey[2:, 1:-1] - padded_E0_Ey[1:-1, 1:-1] + padded_E0_Ey[:-2, 1:-1]) / self.dx**2.0
+ (padded_E0_Ey[1:-1, 2:] - padded_E0_Ey[1:-1, 1:-1] + padded_E0_Ey[1:-1, :-2]) / self.dy**2.0
- (padded_E0_rho[1:-1, 2:] - padded_E0_rho[1:-1, :-2]) / self.dy / 2.0
)
)

return tpd

def density_gradient(self, ex, background_density):
grad_n = jnp.gradient(background_density, self.dx, axis=0) / self.envelope_density
return -0.5 * 1j * self.wp0 * grad_n * ex

def single_step(self, t, y, args) -> jnp.array:
"""
Steps the epw forward in time in the form of a divE equation
This is a finite-difference based approach.
All of this mostly happens in real space.
Args:
t (float): time
y (jnp.array): divE(x,y)
args (Dict[str, jnp.array]): additional arguments
Returns:
divE (jnp.array]: updated divE(x,y)
"""

# calculate ex and ey from divE
ex_ey = self.calc_fields_from_divE(y)

# iaw source

# apply boundary damping to e fields
# ex_ey *= self.cfg["grid"]["absorbing_boundaries"][..., None]

# convert fields back to divE
# divE = self.calc_divE_from_fields(ex_ey)

divE = y
# charge density linear update
ddivE = self.linear_propagate(divE=divE, vte_sq=args["vte_sq"], background_density=args["background_density"])

# SRS source term

# TPD source term
if self.cfg["terms"]["epw"]["source"]["tpd"]:
ddivE += self.tpd(t=t, E0=args["E0"], divE=divE, ey=ex_ey[..., 1])

# density gradient term
if self.cfg["terms"]["epw"]["density_gradient"]:
ddivE += self.density_gradient(ex_ey[..., 0], background_density=args["background_density"])

return ddivE

def __call__(self, t, y, args):
args = {"E0": y["E0"], "background_density": y["background_density"], "vte_sq": y["vte_sq"]}
divE = y["div_E"]

divE += self.dt * jnp.real(self.single_step(t, divE, args))
divE += 1j * self.dt * jnp.imag(self.single_step(t + self.dt * 0.5, divE, args))

return divE
Loading

0 comments on commit df1510a

Please sign in to comment.