Skip to content

Commit

Permalink
Fix helmholtz initializers (#212)
Browse files Browse the repository at this point in the history
fixed missing default params for helmholtz solver
  • Loading branch information
astanziola authored Sep 25, 2023
1 parent 83a6050 commit 2c0f9aa
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 105 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased]
### Added
- Added `numbers_with_smallest_primes` utility to find grids with small primes for efficient FFT when using FourierSeries
### Fixed
- Restored `default_params` for the helmholtz operators that wen missing since the last jaxdf update

## [0.1.4] - 2023-06-29
### Changed
Expand Down
123 changes: 67 additions & 56 deletions docs/notebooks/harmonic/helmholtz_solver_differentiable.ipynb

Large diffs are not rendered by default.

103 changes: 54 additions & 49 deletions jwave/acoustics/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def laplacian_with_pml(u: Continuous,
grad_u = gradient(u)
mod_grad_u = grad_u * pml
mod_diag_jacobian = diag_jacobian(mod_grad_u) * pml
return sum_over_dims(mod_diag_jacobian), None
return sum_over_dims(mod_diag_jacobian)


@operator
Expand Down Expand Up @@ -92,10 +92,31 @@ def laplacian_with_pml(u: OnGrid,
rho_u = sum_over_dims(mod_grad_u * grad_rho0) / rho0

# Put everything together
return nabla_u - rho_u, None
return nabla_u - rho_u


@operator
def on_grid_pml_init(u: OnGrid, medium: Medium, omega, *args, **kwargs):
return [
u.replace_params(
complex_pml_on_grid(medium, omega, shift=u.domain.dx[0] / 2)),
u.replace_params(
complex_pml_on_grid(medium, omega, shift=-u.domain.dx[0] / 2)),
]


def fd_laplacian_with_pml_init(u: FiniteDifferences, medium: Medium, omega,
*args, **kwargs):
return {
"pml_on_grid": on_grid_pml_init(u, medium, omega),
"stencils": {
"gradient": gradient.default_params(u, stagger=[0.5]),
"gradient_unstaggered": gradient.default_params(u),
"diag_jacobian": diag_jacobian.default_params(u, stagger=[-0.5]),
},
}


@operator(init_params=fd_laplacian_with_pml_init)
def laplacian_with_pml(u: FiniteDifferences,
medium: Medium,
*,
Expand All @@ -113,25 +134,6 @@ def laplacian_with_pml(u: FiniteDifferences,
FiniteDifferences: Modified Laplacian operator applied to `u`.
"""
rho0 = medium.density
if params == None:
params = {
"pml_on_grid": [
u.replace_params(
complex_pml_on_grid(medium,
omega,
shift=u.domain.dx[0] / 2)),
u.replace_params(
complex_pml_on_grid(medium,
omega,
shift=-u.domain.dx[0] / 2)),
],
"stencils": {
"gradient": gradient.default_params(u, stagger=[0.5]),
"gradient_unstaggered": gradient.default_params(u),
"diag_jacobian": diag_jacobian.default_params(u,
stagger=[-0.5]),
},
}

pml = params["pml_on_grid"]
stencils = params["stencils"]
Expand All @@ -154,10 +156,18 @@ def laplacian_with_pml(u: FiniteDifferences,
params=stencils["gradient_unstaggered"])
rho_u = sum_over_dims(mod_grad_u * grad_rho0) / rho0

return nabla_u - rho_u, params
return nabla_u - rho_u


@operator
def fourier_laplacian_with_pml_init(u: FourierSeries, medium: Medium, omega,
*args, **kwargs):
return {
"pml_on_grid": on_grid_pml_init(u, medium, omega),
"fft_u": gradient.default_params(u),
}


@operator(init_params=fourier_laplacian_with_pml_init)
def laplacian_with_pml(u: FourierSeries,
medium: Medium,
*,
Expand All @@ -176,23 +186,6 @@ def laplacian_with_pml(u: FourierSeries,
"""
rho0 = medium.density

# Initialize pml parameters if not provided
if params == None:
params = {
"pml_on_grid": [
u.replace_params(
complex_pml_on_grid(medium,
omega,
shift=u.domain.dx[0] / 2)),
u.replace_params(
complex_pml_on_grid(medium,
omega,
shift=-u.domain.dx[0] / 2)),
],
"fft_u":
gradient.default_params(u),
}

pml = params["pml_on_grid"]

# Making laplacian
Expand Down Expand Up @@ -225,7 +218,7 @@ def laplacian_with_pml(u: FourierSeries,
rho_u = sum_over_dims(_ru) / rho0

# Put everything together
return nabla_u - rho_u, params
return nabla_u - rho_u


@operator
Expand All @@ -246,10 +239,17 @@ def wavevector(u: Field, medium: Medium, *, omega=1.0, params=None) -> Field:
trans_fun = lambda x: db2neper(x, 2.0)
alpha = compose(alpha)(trans_fun)
k_mod = (omega / c)**2 + 2j * (omega**3) * alpha / c
return u * k_mod, None
return u * k_mod


@operator
def helmholtz_init_params(u: Field, medium: Medium, omega, *args, **kwargs):
return {
"laplacian": laplacian_with_pml.default_params(u, medium, omega=omega),
"wavevector": wavevector.default_params(u, medium, omega=omega),
}


@operator(init_params=helmholtz_init_params)
def helmholtz(u: Field, medium: Medium, *, omega=1.0, params=None) -> Field:
r"""Evaluates the Helmholtz operator on a field $u$ with a PML.
Expand All @@ -262,15 +262,22 @@ def helmholtz(u: Field, medium: Medium, *, omega=1.0, params=None) -> Field:
Returns:
Field: Helmholtz operator applied to `u`.
"""
lapl_params, wavevector_params = params["laplacian"], params["wavevector"]

# Get the modified laplacian
L = laplacian_with_pml(u, medium, omega)
L = laplacian_with_pml(u, medium, omega, params=lapl_params)

# Add the wavenumber term
k = wavevector(u, medium, omega)
k = wavevector(u, medium, omega, params=wavevector_params)
return L + k, None


@operator
def ongrid_helmholtz_init_params(u: OnGrid, medium: Medium, omega, *args,
**kwargs):
return laplacian_with_pml.default_params(u, medium, omega=omega)


@operator(init_params=ongrid_helmholtz_init_params)
def helmholtz(u: OnGrid, medium: Medium, *, omega=1.0, params=None) -> OnGrid:
r"""Evaluates the Helmholtz operator on a field $u$ with a PML. This
implementation exposes the laplacian parameters to the user.
Expand All @@ -284,8 +291,6 @@ def helmholtz(u: OnGrid, medium: Medium, *, omega=1.0, params=None) -> OnGrid:
Returns:
OnGrid: Helmholtz operator applied to `u`.
"""
if params == None:
params = laplacian_with_pml.default_params(u, medium, omega=omega)

# Get the modified laplacian
L = laplacian_with_pml(u, medium, omega=omega, params=params)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_helmholtz.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from jax import numpy as jnp

from jwave.acoustics.operators import helmholtz
from jwave.acoustics.time_harmonic import helmholtz_solver
from jwave.geometry import Domain, FourierSeries, Medium

Expand All @@ -40,5 +41,22 @@ def test_if_homog_helmholtz_runs():
)


def test_default_params():
N = (128, 128)
domain = Domain(N, (1.0, 1.0))
field = jnp.zeros(N).astype(jnp.complex64)
field = FourierSeries(field, domain)
sos = jnp.ones(N)
sos = FourierSeries(sos, domain)

medium = Medium(domain, sound_speed=sos, pml_size=15)

default_params = helmholtz.default_params(field, medium, omega=1.0)

# Check that 'pml_on_grid', 'fft_u' are in the dict
assert 'pml_on_grid' in default_params.keys()
assert 'fft_u' in default_params.keys()


if __name__ == "__main__":
test_if_homog_helmholtz_runs()

0 comments on commit 2c0f9aa

Please sign in to comment.