Skip to content

Commit

Permalink
More refactoring for use as a package (#75)
Browse files Browse the repository at this point in the history
* removing bandwidth code, adding a function to help calculate coherence, adding commit hash versioning

* needs a version
  • Loading branch information
joglekara authored Sep 18, 2024
1 parent d489391 commit ca36e6e
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 76 deletions.
23 changes: 21 additions & 2 deletions adept/_lpse2d/core/laser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, Tuple
from jax import numpy as jnp
import numpy as np
from jax import numpy as jnp, Array


class Light:
Expand Down Expand Up @@ -35,3 +34,23 @@ def laser_update(self, t: float, y: jnp.ndarray, light_wave: Dict) -> Tuple[jnp.
# raise NotImplementedError

return E0

def calc_ey_at_one_point(self, t: float, density: Array, light_wave: Dict) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
This function is used to calculate the coherence time of the laser
:param t: time
:param y: state variables
:return: updated laser field
"""

wpe = self.w0 * jnp.sqrt(density)[None, 0, 0]
k0 = self.w0 / self.c * jnp.sqrt((1 + 0j + light_wave["delta_omega"]) ** 2 - wpe**2 / self.w0**2)
E0_static = (
(1 + 0j - wpe**2.0 / (self.w0 * (1 + light_wave["delta_omega"])) ** 2) ** -0.25
* self.E0_source
* jnp.sqrt(light_wave["intensities"])
* jnp.exp(1j * k0 * self.x[0] + 1j * light_wave["initial_phase"])
)
dE0y = E0_static * jnp.exp(-1j * light_wave["delta_omega"] * self.w0 * t)
return jnp.sum(dE0y, axis=0)
44 changes: 2 additions & 42 deletions adept/_lpse2d/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,55 +483,15 @@ def plot_kt(kfields, td):
# kx = kfields.coords["kx"].data


def post_process(result, cfg: Dict, td: str, args) -> Tuple[xr.Dataset, xr.Dataset]:
used_driver = args["drivers"]
import pickle

with open(os.path.join(td, "used_driver.pkl"), "wb") as fi:
pickle.dump(used_driver, fi)

if "E0" in used_driver:
dw_over_w = used_driver["E0"]["delta_omega"] # / cfg["units"]["derived"]["w0"] - 1
fig, ax = plt.subplots(1, 3, figsize=(13, 5), tight_layout=True)
ax[0].plot(dw_over_w, used_driver["E0"]["intensities"], "o")
ax[0].grid()
ax[0].set_xlabel(r"$\Delta \omega / \omega_0$", fontsize=14)
ax[0].set_ylabel("$|E|$", fontsize=14)
ax[1].semilogy(dw_over_w, used_driver["E0"]["intensities"], "o")
ax[1].grid()
ax[1].set_xlabel(r"$\Delta \omega / \omega_0$", fontsize=14)
ax[1].set_ylabel("$|E|$", fontsize=14)
ax[2].plot(dw_over_w, used_driver["E0"]["initial_phase"], "o")
ax[2].grid()
ax[2].set_xlabel(r"$\Delta \omega / \omega_0$", fontsize=14)
ax[2].set_ylabel(r"$\angle E$", fontsize=14)
plt.savefig(os.path.join(td, "driver_that_was_used.png"), bbox_inches="tight")
plt.close()

# numerator = np.fft.ifft2(np.abs(np.fft.fft2(result.ys["fields"]["epw"], axes=(1, 2))) ** 2, axes=(1, 2))
# denominator = 1 / tmax * np.sum(np.abs(result.ys["fields"]["epw"]) ** 2, axis=0) * dt
def post_process(result, cfg: Dict, td: str) -> Tuple[xr.Dataset, xr.Dataset]:

os.makedirs(os.path.join(td, "binary"))
kfields, fields = make_xarrays(cfg, result.ts, result.ys, td)

plot_fields(fields, td)
plot_kt(kfields, td)

dx = fields.coords["x (um)"].data[1] - fields.coords["x (um)"].data[0]
dy = fields.coords["y (um)"].data[1] - fields.coords["y (um)"].data[0]
dt = fields.coords["t (ps)"].data[1] - fields.coords["t (ps)"].data[0]

metrics = {}
tint = 5.0 # last tint ps
it = int(tint / dt)
total_esq = np.abs(fields["ex"][-it:].data) ** 2 + np.abs(fields["ey"][-it:].data ** 2) * dx * dy * dt
metrics[f"total_e_sq_last_{tint}_ps".replace(".", "p")] = float(np.sum(total_esq))
metrics[f"log10_total_e_sq_last_{tint}_ps".replace(".", "p")] = float(
np.log10(metrics[f"total_e_sq_last_{tint}_ps".replace(".", "p")])
)
metrics[f"growth_rate_last_{tint}_ps".replace(".", "p")] = float(np.mean(np.gradient(np.log(total_esq), dt)))

return {"k": kfields, "x": fields, "metrics": metrics}
return {"k": kfields, "x": fields, "metrics": {}}


def make_xarrays(cfg, ts, ys, td):
Expand Down
2 changes: 1 addition & 1 deletion adept/_lpse2d/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, cfg) -> None:
super().__init__(cfg)

def post_process(self, run_output: Dict, td: str) -> Dict:
return post_process(run_output["solver result"], self.cfg, td, run_output["args"])
return post_process(run_output["solver result"], self.cfg, td)

def write_units(self) -> Dict:
"""
Expand Down
44 changes: 13 additions & 31 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#!/usr/bin/env python

import os
import sys
import os, sys, subprocess

from setuptools import setup
from setuptools import setup, find_packages

here = os.path.abspath(os.path.dirname(__file__))
sys.path.append(here)
Expand All @@ -15,6 +14,14 @@
long_description = f.read()


# get the current git commit hash
def get_git_commit_hash():
try:
return subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8")
except subprocess.CalledProcessError:
return "unknown"


setup(
# metadata
name="adept",
Expand All @@ -24,10 +31,10 @@
url="https://github.com/ergodicio/adept",
author="Archis Joglekar",
author_email="archis@ergodic.io",
version=1.0, # versioneer.get_version(),
version="0.0.1+" + get_git_commit_hash(),
# cmdclass=versioneer.get_cmdclass(),
packages=["adept"],
python_requires=">=3.8",
packages=find_packages(),
python_requires=">=3.10",
install_requires=[
"jax[cuda12]",
"diffrax",
Expand All @@ -48,29 +55,4 @@
"interpax",
"tabulate",
],
# extras_require={
# "dev": [
# "fastapi",
# "httpx", # required by fastapi test client
# "requests",
# "numpy",
# "pre-commit",
# "pytest",
# "pytest-cov",
# "sphinx",
# ],
# },
# package_data={
# "tesseract": [
# "templates/**/*",
# # Ensure tesseract_runtime folder is copied to site-packages when installing
# "../tesseract_runtime/**/*",
# ],
# },
# zip_safe=False,
# entry_points={
# "console_scripts": [
# "tesseract=tesseract.cli:entrypoint",
# ],
# },
)

0 comments on commit ca36e6e

Please sign in to comment.