-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implemented ADEPTModule as a better interface for the solvers (#46)
* adeptmodule and ergoexo * all two fluid tests passing * passing vfp tests * cleanup * passing vlasov1d * vlasov cleanup * module system -- working LPSE * bandwidth modules for LPSE and a twostream example for vlasov1d * log system metrics
- Loading branch information
Showing
42 changed files
with
2,448 additions
and
933 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,266 @@ | ||
from adept import tf1d | ||
from typing import Dict, Tuple, Callable | ||
import jax.flatten_util | ||
import os, time, tempfile, yaml, pickle | ||
|
||
|
||
from diffrax import Solution, Euler, RESULTS | ||
import mlflow, jax, numpy as np | ||
from jax import numpy as jnp | ||
|
||
|
||
from utils import misc | ||
|
||
|
||
def get_envelope(p_wL, p_wR, p_L, p_R, ax):
    """
    Evaluate a tanh-shaped envelope on ``ax``.

    The result is half the difference between a tanh ramp offset by ``p_L``
    and scaled by ``p_wL``, and a tanh ramp offset by ``p_R`` and scaled by ``p_wR``.

    Args:
        p_wL: scale (width) of the left ramp
        p_wR: scale (width) of the right ramp
        p_L: offset of the left ramp
        p_R: offset of the right ramp
        ax: value(s) at which the envelope is evaluated

    Returns:
        ``0.5 * (tanh((ax - p_L) / p_wL) - tanh((ax - p_R) / p_wR))``
    """
    left_ramp = jnp.tanh((ax - p_L) / p_wL)
    right_ramp = jnp.tanh((ax - p_R) / p_wR)
    return 0.5 * (left_ramp - right_ramp)
|
||
|
||
class Stepper(Euler):
    """
    Dummy stepper built on top of ``diffrax.Euler``.

    Instead of applying an Euler update, ``step`` takes whatever the vector field
    returns and hands it back as the next state.
    """

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        """
        Advance the state by evaluating the vector field once.

        ``t1``, ``solver_state`` and ``made_jump`` are accepted but not used.

        Returns:
            A 5-tuple: the new state, ``None``, a dict holding the old (``y0``) and
            new (``y1``) state, ``None``, and ``RESULTS.successful``.
        """
        # The vector field output is used directly as the next state.
        next_state = terms.vf(t0, y0, args)
        return next_state, None, {"y0": y0, "y1": next_state}, None, RESULTS.successful
|
||
|
||
class ADEPTModule:
    """
    This class is the base class for all the ADEPT modules. It defines the interface that all the ADEPT modules must implement.

    Every hook here is a no-op default; subclasses override the ones they need.

    Args:
        cfg: The configuration dictionary
    """

    def __init__(self, cfg) -> None:
        self.cfg = cfg

    def post_process(self, run_output: Dict, td: str) -> Dict:
        """
        Post-process the output of a run.

        Args:
            run_output: whatever ``__call__`` returned
            td: path of a directory the implementation may write files into

        Returns:
            A dictionary of post-processing results. The base implementation returns an
            empty dictionary (rather than ``None``) so callers can safely test for keys
            such as ``"metrics"``.
        """
        return {}

    def write_units(self) -> Dict:
        """Return the dictionary of units; empty in the base class."""
        return {}

    def init_diffeqsolve(self):
        """Hook for preparing the differential-equation solve; no-op in the base class."""
        pass

    def get_derived_quantities(self):
        """Hook for computing derived quantities; no-op in the base class."""
        pass

    def get_solver_quantities(self):
        """Hook for computing solver quantities; no-op in the base class."""
        pass

    def get_save_func(self):
        """Hook for building the save function; no-op in the base class."""
        pass

    def init_state_and_args(self) -> Dict:
        """Hook for initializing the state and args; returns an empty dictionary in the base class."""
        return {}

    def init_modules(self) -> Dict:
        """Return the dictionary of modules; empty in the base class."""
        return {}

    def __call__(self, trainable_modules: Dict, args: Dict):
        """Run the module; the base class does nothing and returns an empty dictionary."""
        return {}

    def vg(self, trainable_modules: Dict, args: Dict):
        """
        Value and gradient of ``__call__`` with respect to ``trainable_modules``.

        Subclasses are expected to implement this, e.g. along the lines of
        ``eqx.filter_value_and_grad(self.__call__)(trainable_modules)``.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            "This is the base class and does not have a gradient implemented. This is "
            "likely because there is no metric in place. Subclass this class and implement the gradient"
        )
|
||
|
||
class ergoExo:
    """
    This class is the main interface for running a simulation. It is responsible for calling all the ADEPT modules in the right order
    and logging parameters and results to mlflow.

    This helps decouple the numerical solvers from the experiment management
    """

    def __init__(self, mlflow_run_id: str = None, mlflow_nested: bool = None) -> None:
        """
        Args:
            mlflow_run_id: id of an existing mlflow run to attach to. When ``None`` a new run is
                created in ``setup``.
            mlflow_nested: forwarded as ``nested`` to ``mlflow.start_run``; ``None`` means ``False``.
        """
        self.mlflow_run_id = mlflow_run_id
        self.mlflow_nested = False if mlflow_nested is None else mlflow_nested

        # Optional parent directory for all temporary directories created by this class.
        self.base_tempdir = os.environ.get("BASE_TEMPDIR")

        self.ran_setup = False

    def get_adept_module(self, cfg: Dict) -> ADEPTModule:
        """
        Instantiate the ADEPT module selected by ``cfg["solver"]``.

        Args:
            cfg: The configuration dictionary; ``cfg["solver"]`` selects the solver.

        Returns:
            The solver module constructed from ``cfg``.

        Raises:
            NotImplementedError: if ``cfg["solver"]`` is not one of the supported solvers.
        """
        # Imports are local so that only the selected solver package gets imported.
        # NOTE: sh-2d, vlasov-1d2v and vlasov-2d are not wired up here yet.
        if cfg["solver"] == "tf-1d":
            from adept.tf1d.base import BaseTwoFluid1D as this_module
        elif cfg["solver"] == "vlasov-1d":
            from adept.vlasov1d.base import BaseVlasov1D as this_module
        elif cfg["solver"] == "envelope-2d":
            from adept.lpse2d.base import BaseLPSE2D as this_module
        elif cfg["solver"] == "vfp-1d":
            from adept.vfp1d.base import BaseVFP1D as this_module
        else:
            raise NotImplementedError("This solver approach has not been implemented yet")

        return this_module(cfg)

    def _setup_(self, cfg: Dict, td: str, adept_module: ADEPTModule = None):
        """
        Build (or adopt) the ADEPT module, dump its configuration at each stage into ``td``
        and run its initialization hooks.

        Args:
            cfg: The configuration dictionary
            td: directory the configuration files are written into
            adept_module: an already constructed module to use instead of building one from ``cfg``

        Returns:
            Whatever the module's ``init_modules`` returns.
        """
        if adept_module is None:
            self.adept_module = self.get_adept_module(cfg)
        else:
            self.adept_module = adept_module

        # dump raw config
        with open(os.path.join(td, "config.yaml"), "w") as fi:
            yaml.dump(self.adept_module.cfg, fi)

        # dump units
        quants_dict = self.adept_module.write_units()
        with open(os.path.join(td, "units.yaml"), "w") as fi:
            yaml.dump(quants_dict, fi)

        # dump derived config
        self.adept_module.get_derived_quantities()  # gets the derived quantities
        misc.log_params(self.adept_module.cfg)  # logs the parameters to mlflow
        with open(os.path.join(td, "derived_config.yaml"), "w") as fi:
            yaml.dump(self.adept_module.cfg, fi)

        # dump array config
        # NOTE: despite the .yaml file name this file is written with pickle.
        self.adept_module.get_solver_quantities()
        with open(os.path.join(td, "array_config.yaml"), "wb") as fi:
            pickle.dump(self.adept_module.cfg, fi)

        self.adept_module.init_state_and_args()
        self.adept_module.init_diffeqsolve()
        modules = self.adept_module.init_modules()

        self.ran_setup = True

        return modules

    def setup(self, cfg: Dict, adept_module: ADEPTModule = None) -> Dict:
        """
        This function sets up the simulation by getting the helper functions for the given solver

        If no mlflow run id was given at construction, a new run is started and its id is stored
        on ``self.mlflow_run_id``. Otherwise the existing run is re-opened and the configuration
        is read back from that run's artifacts, replacing the ``cfg`` argument.

        Args:
            cfg: The configuration dictionary
            adept_module: an already constructed module to use instead of building one from ``cfg``

        Returns:
            The modules returned by the ADEPT module's ``init_modules``.
        """
        with tempfile.TemporaryDirectory(dir=self.base_tempdir) as td:
            if self.mlflow_run_id is None:
                mlflow.set_experiment(cfg["mlflow"]["experiment"])
                with mlflow.start_run(run_name=cfg["mlflow"]["run"], nested=self.mlflow_nested) as mlflow_run:
                    modules = self._setup_(cfg, td, adept_module)
                    # Without this the dumped config files vanish with the temporary directory.
                    mlflow.log_artifacts(td)
                    self.mlflow_run_id = mlflow_run.info.run_id

            else:
                with mlflow.start_run(run_id=self.mlflow_run_id, nested=self.mlflow_nested) as mlflow_run:
                    with tempfile.TemporaryDirectory(dir=self.base_tempdir) as temp_path:
                        cfg = misc.get_cfg(artifact_uri=mlflow_run.info.artifact_uri, temp_path=temp_path)
                    modules = self._setup_(cfg, td, adept_module)
                    mlflow.log_artifacts(td)

        return modules

    def _post_process_and_log_(self, run_output: Dict):
        """
        Post-process ``run_output`` and log the resulting artifacts, metrics and timing.

        Must be called inside an active mlflow run.

        Returns:
            Whatever the ADEPT module's ``post_process`` returned.
        """
        t0 = time.time()
        with tempfile.TemporaryDirectory(dir=self.base_tempdir) as td:
            post_processing_output = self.adept_module.post_process(run_output, td)
            mlflow.log_artifacts(td)  # logs the temporary directory to mlflow

        # A module whose post_process returns nothing must not break the run.
        if post_processing_output is not None and "metrics" in post_processing_output:
            mlflow.log_metrics(post_processing_output["metrics"])
        mlflow.log_metrics({"postprocess_time": round(time.time() - t0, 4)})

        return post_processing_output

    def __call__(self, modules: Dict = None) -> Tuple[Solution, Dict, str]:
        """
        This function is the main entry point for running a simulation. It runs the ADEPT module
        that was prepared by ``setup`` and post-processes its output, logging timings, artifacts
        and metrics to mlflow.

        Args:
            modules: passed to the ADEPT module as its trainable modules

        Returns:
            A tuple of the run output, the post-processing output, and the mlflow run id
        """
        assert self.ran_setup, "You must run self.setup() before running the simulation"

        with mlflow.start_run(run_id=self.mlflow_run_id, nested=self.mlflow_nested, log_system_metrics=True):
            t0 = time.time()
            run_output = self.adept_module(modules, None)
            mlflow.log_metrics({"run_time": round(time.time() - t0, 4)})  # logs the run time to mlflow

            post_processing_output = self._post_process_and_log_(run_output)

        return run_output, post_processing_output, self.mlflow_run_id

    def val_and_grad(self, modules: Dict = None):
        """
        This function is the value and gradient of the simulation. It assumes that the ADEPT module
        implements ``vg``.

        Args:
            modules: The parameters to run the simulation and take the gradient against. All the other parameters are
                static

        Returns:
            val - The value of the simulation, grad - The gradient of the simulation with respect to the
            parameters, and a tuple of (run output, post-processing output, mlflow run id)
        """
        assert self.ran_setup, "You must run self.setup() before running the simulation"

        with mlflow.start_run(run_id=self.mlflow_run_id, nested=self.mlflow_nested, log_system_metrics=True):
            t0 = time.time()
            (val, run_output), grad = self.adept_module.vg(modules, None)
            # Flatten the gradient pytree so a single L2 norm can be logged.
            flattened_grad, _ = jax.flatten_util.ravel_pytree(grad)
            mlflow.log_metrics({"run_time": round(time.time() - t0, 4)})  # logs the run time to mlflow
            mlflow.log_metrics({"val": float(val), "l2-grad": float(np.linalg.norm(flattened_grad))})

            post_processing_output = self._post_process_and_log_(run_output)

        return val, grad, (run_output, post_processing_output, self.mlflow_run_id)
|
||
def _log_flops_(_run_: Callable, models: Dict, state: Dict, args: Dict, tqs):
    """
    Logs the number of flops to mlflow

    The function is lowered to an HLO module, the backend's cost analysis is run on it,
    and the reported ``"flops"`` entry is logged in units of 1e9 under ``"total GigaFLOP"``.

    NOTE(review): ``jax.xla_computation`` and ``jax.lib.xla_bridge`` may be deprecated in
    recent JAX releases -- confirm against the pinned JAX version.

    Args:
        _run_: The function that runs the simulation
        models: The models used in the simulation
        state: The state passed to ``_run_``
        args: The args passed to ``_run_``
        tqs: The time quantities used in the simulation
    """
    hlo_module = jax.xla_computation(_run_)(models, state, args, tqs).as_hlo_module()
    backend = jax.lib.xla_bridge.get_backend()
    cost = jax.lib.xla_client._xla.hlo_module_cost_analysis(backend, hlo_module)
    mlflow.log_metrics({"total GigaFLOP": cost["flops"] / 1e9})  # logs the flops to mlflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
from typing import Dict | ||
import numpy as np | ||
from astropy.units import Quantity as _Q | ||
from diffrax import diffeqsolve, SaveAt, ODETerm | ||
from equinox import filter_jit | ||
|
||
from adept import ADEPTModule, Stepper | ||
from adept.lpse2d.helpers import ( | ||
write_units, | ||
post_process, | ||
get_derived_quantities, | ||
get_solver_quantities, | ||
get_save_quantities, | ||
get_density_profile, | ||
) | ||
from adept.lpse2d.vector_field import SplitStep | ||
from adept.lpse2d.modules.driver import BandwidthModule | ||
|
||
|
||
class BaseLPSE2D(ADEPTModule):
    """
    ADEPT module for the LPSE2D solver.

    Wires the helper functions from ``adept.lpse2d.helpers`` into the ``ADEPTModule``
    interface and runs the solve with ``diffrax.diffeqsolve``.
    """

    def __init__(self, cfg) -> None:
        super().__init__(cfg)

    def post_process(self, run_output: Dict, td: str) -> Dict:
        """
        Post-process the result of ``__call__``.

        Args:
            run_output: dictionary with the keys ``"solver result"`` and ``"args"``
            td: directory handed to the post-processing helper

        Returns:
            Whatever the ``post_process`` helper returns.
        """
        return post_process(run_output["solver result"], self.cfg, td, run_output["args"])

    def write_units(self) -> Dict:
        """
        Compute the units from the configuration.

        Returns:
            Whatever the ``write_units`` helper returns for ``self.cfg``.
        """
        return write_units(self.cfg)

    def get_derived_quantities(self):
        """Replace ``self.cfg`` with the configuration extended by the derived quantities."""
        self.cfg = get_derived_quantities(self.cfg)

    def get_solver_quantities(self):
        """Replace ``self.cfg["grid"]`` with the solver quantities computed from ``self.cfg``."""
        self.cfg["grid"] = get_solver_quantities(self.cfg)

    def init_modules(self) -> Dict:
        """Return the trainable modules: a single ``BandwidthModule`` under ``"bandwidth"``."""
        return {"bandwidth": BandwidthModule(self.cfg)}

    def init_diffeqsolve(self):
        """
        Prepare the quantities needed by ``diffeqsolve``.

        Sets ``self.time_quantities`` and ``self.diffeqsolve_quants`` and updates ``self.cfg``
        with the save quantities.
        """
        self.cfg = get_save_quantities(self.cfg)
        self.time_quantities = {
            "t0": 0.0,
            "t1": self.cfg["grid"]["tmax"],
            "max_steps": self.cfg["grid"]["max_steps"],
            "save_t0": 0.0,
            "save_t1": self.cfg["grid"]["tmax"],
            # NOTE(review): save_nt is set to tmax, which looks like a copy-paste slip -- confirm.
            "save_nt": self.cfg["grid"]["tmax"],
        }

        self.diffeqsolve_quants = dict(
            terms=ODETerm(SplitStep(self.cfg)),
            solver=Stepper(),
            saveat=dict(ts=self.cfg["save"]["t"]["ax"], fn=self.cfg["save"]["func"]),
        )

    def init_state_and_args(self) -> Dict:
        """
        Build the initial state and args and store them on ``self.state`` and ``self.args``.

        Nothing is returned, despite the ``Dict`` annotation; the results live on the instance.

        Raises:
            NotImplementedError: if ``cfg["density"]["noise"]["type"]`` is neither
                ``"uniform"`` nor ``"normal"``.
        """
        grid_shape = (self.cfg["grid"]["nx"], self.cfg["grid"]["ny"])
        noise_cfg = self.cfg["density"]["noise"]

        # NOTE(review): random_amps is drawn but not used below -- confirm intent. The draw is
        # kept so that the global numpy random stream is consumed exactly as before.
        if noise_cfg["type"] == "uniform":
            random_amps = np.random.uniform(noise_cfg["min"], noise_cfg["max"], grid_shape)

        elif noise_cfg["type"] == "normal":
            loc = 0.5 * (noise_cfg["min"] + noise_cfg["max"])
            scale = 1.0
            random_amps = np.random.normal(loc, scale, grid_shape)

        else:
            raise NotImplementedError

        random_phases = np.random.uniform(0, 2 * np.pi, grid_shape)
        phi_noise = 1 * np.exp(1j * random_phases)
        # Multiplying by zero yields a complex array of zeros shaped like the noise.
        epw = 0 * phi_noise

        background_density = get_density_profile(self.cfg)
        vte_sq = np.ones(grid_shape) * self.cfg["units"]["derived"]["vte"] ** 2
        E0 = np.zeros(grid_shape + (2,), dtype=np.complex128)
        state = {"background_density": background_density, "epw": epw, "E0": E0, "vte_sq": vte_sq}

        # Every array is stored as a float64 view of its buffer.
        self.state = {k: v.view(dtype=np.float64) for k, v in state.items()}
        self.args = {"drivers": {"E0": {}}}

    @filter_jit
    def __call__(self, trainable_modules: Dict, args: Dict = None) -> Dict:
        """
        Apply the trainable modules to the state and args, then run the solve.

        Args:
            trainable_modules: mapping of name to module; each module is called as
                ``module(state, args)`` and must return the updated ``(state, args)``.
                ``None`` or an empty mapping leaves the state and args untouched.
            args: arguments for the solve; ``self.args`` is used when ``None``

        Returns:
            A dictionary with the solver result under ``"solver result"`` and the final
            args under ``"args"``.
        """
        if args is None:
            args = self.args

        # Start from the stored state so the solve also works without any modules, and
        # thread state and args through the modules so no module's update is discarded.
        state = self.state
        for module in (trainable_modules or {}).values():
            state, args = module(state, args)

        solver_result = diffeqsolve(
            terms=self.diffeqsolve_quants["terms"],
            solver=self.diffeqsolve_quants["solver"],
            t0=self.time_quantities["t0"],
            t1=self.time_quantities["t1"],
            max_steps=self.cfg["grid"]["max_steps"],
            dt0=self.cfg["grid"]["dt"],
            y0=state,
            args=args,
            saveat=SaveAt(**self.diffeqsolve_quants["saveat"]),
        )

        return {"solver result": solver_result, "args": args}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.