Skip to content

Commit

Permalink
revert changes to interrogate and add missing docs
Browse files Browse the repository at this point in the history
  • Loading branch information
ktangsali committed Nov 3, 2023
1 parent 42d2041 commit 23fb1db
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repos:
"--ignore-magic", "--fail-under=99", "--exclude=['setup.py', 'test', 'build', 'docs']",
"--ignore-regex=['forward', 'backward', 'reset_parameters', 'extra_repr', 'MetaData', 'apply_activation','exec_activation']",
"--color", "--"]
exclude: ^modulus/internal/|^examples/
exclude: ^modulus/internal/

- repo: https://github.com/igorshubovych/markdownlint-cli
rev: v0.35.0
Expand Down
2 changes: 2 additions & 0 deletions examples/cfd/ahmed_body_mgn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def dgl_to_pyvista(graph: DGLGraph):


class AhmedBodyRollout:
"""MGN inference on Ahmed Body dataset"""

def __init__(self, wb, logger):
# set device
self.device = "cuda" if torch.cuda.is_available() else "cpu"
Expand Down
6 changes: 6 additions & 0 deletions examples/cfd/darcy_nested_fnos/evaluate_nested_darcy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@


def plot_assembled(perm, darc):
"""Utility for plotting"""
headers = ["permeability", "darcy"]
plt.rcParams.update({"font.size": 28})
fig, ax = plt.subplots(1, 2, figsize=(15 * 2, 15), sharey=True)
Expand All @@ -52,6 +53,7 @@ def EvaluateModel(
parent_result: FloatTensor = None,
log: PythonLogger = None,
):
"""Utility for running inference on trained model"""
# define model and load weights
dist = DistributedManager()
log.info(f"evaluating model {model_name}")
Expand Down Expand Up @@ -111,6 +113,7 @@ def forward_eval(invars):


def AssembleSolutionToDict(cfg: DictConfig, perm: dict, darcy: dict, pos: dict):
"""Assemble solution to easily interpretable dict"""
dat, idx = {}, 0
for ii in range(perm["ref0"].shape[0]):
samp = str(ii)
Expand Down Expand Up @@ -142,6 +145,7 @@ def AssembleSolutionToDict(cfg: DictConfig, perm: dict, darcy: dict, pos: dict):


def AssembleToSingleField(cfg: DictConfig, dat: dict):
"""Assemble multiple fields to a single dict"""
ref_fac = cfg.ref_fac
glob_size = dat["0"]["ref0"]["0"]["darcy"].shape[0]
inset_size = dat["0"]["ref1"]["0"]["darcy"].shape[0]
Expand Down Expand Up @@ -175,6 +179,7 @@ def AssembleToSingleField(cfg: DictConfig, dat: dict):


def GetRelativeL2(pred, tar):
"""Compute L2 error"""
div = 1.0 / tar["darcy"].shape[0] * tar["darcy"].shape[1]
err = pred["darcy"] - tar["darcy"]

Expand All @@ -185,6 +190,7 @@ def GetRelativeL2(pred, tar):


def ComputeErrorNorm(cfg: DictConfig, pred_dict: dict, log: PythonLogger, ref0_pred):
"""Compute relative L2-norm of error"""
# assemble ref1 and ref2 solutions alongside gound truth to single scalar field
log.info("computing relative L2-norm of error...")
tar_dict = np.load(cfg.inference.inference_set, allow_pickle=True).item()["fields"]
Expand Down
4 changes: 4 additions & 0 deletions examples/cfd/gray_scott_rnn/gray_scott_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def prepare_data(
predict_nr_tsteps,
start_timestep,
):
"""Data pre-processing"""
if Path(output_data_path).is_file():
pass
else:
Expand Down Expand Up @@ -67,6 +68,7 @@ def prepare_data(


def validation_step(model, dataloader, epoch):
"""Validation Step"""
model.eval()

for data in dataloader:
Expand All @@ -89,6 +91,8 @@ def validation_step(model, dataloader, epoch):


class HDF5MapStyleDataset(Dataset):
"""Simple map-stype HDF5 dataset"""

def __init__(
self,
file_path,
Expand Down
4 changes: 4 additions & 0 deletions examples/cfd/navier_stokes_rnn/navier_stokes_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def prepare_data(
start_idx,
num_samples,
):
"""Data pre-processing"""
if Path(output_data_path).is_file():
pass
else:
Expand Down Expand Up @@ -73,6 +74,7 @@ def prepare_data(


def validation_step(model, dataloader, epoch):
"""Validation step"""
model.eval()

loss_epoch = 0
Expand All @@ -99,6 +101,8 @@ def validation_step(model, dataloader, epoch):


class HDF5MapStyleDataset(Dataset):
"""Simple map-style HDF5 dataset"""

def __init__(
self,
file_path,
Expand Down
27 changes: 27 additions & 0 deletions examples/generative/diffusion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,7 @@ def open_url(


def constant(value, shape=None, dtype=None, device=None, memory_format=None):
"""Cached construction of constant tensors"""
value = np.asarray(value)
if shape is not None:
shape = tuple(shape)
Expand Down Expand Up @@ -591,6 +592,7 @@ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
def nan_to_num(
input, nan=0.0, posinf=None, neginf=None, *, out=None
): # pylint: disable=redefined-builtin
"""Replace NaN/Inf with specified numerical values"""
assert isinstance(input, torch.Tensor)
if posinf is None:
posinf = torch.finfo(input.dtype).max
Expand All @@ -617,6 +619,10 @@ def nan_to_num(

@contextlib.contextmanager
def suppress_tracer_warnings():
"""
Context manager to temporarily suppress known warnings in torch.jit.trace().
Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
"""
flt = ("ignore", None, torch.jit.TracerWarning, None, 0)
warnings.filters.insert(0, flt)
yield
Expand All @@ -630,6 +636,11 @@ def suppress_tracer_warnings():


def assert_shape(tensor, ref_shape):
"""
Assert that the shape of a tensor matches the given list of integers.
None indicates that the size of a dimension is allowed to vary.
Performs symbolic assertion when used in torch.jit.trace().
"""
if tensor.ndim != len(ref_shape):
raise AssertionError(
f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}"
Expand Down Expand Up @@ -660,6 +671,8 @@ def assert_shape(tensor, ref_shape):


def profiled_function(fn):
"""Function decorator that calls torch.autograd.profiler.record_function()."""

def decorator(*args, **kwargs):
with torch.autograd.profiler.record_function(fn.__name__):
return fn(*args, **kwargs)
Expand All @@ -674,6 +687,11 @@ def decorator(*args, **kwargs):


class InfiniteSampler(torch.utils.data.Sampler):
"""
Sampler for torch.utils.data.DataLoader that loops over the dataset
indefinitely, shuffling items as it goes.
"""

def __init__(
self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5
):
Expand Down Expand Up @@ -714,17 +732,20 @@ def __iter__(self):


def params_and_buffers(module):
"""Get parameters and buffers of a nn.Module"""
assert isinstance(module, torch.nn.Module)
return list(module.parameters()) + list(module.buffers())


def named_params_and_buffers(module):
"""Get named parameters and buffers of a nn.Module"""
assert isinstance(module, torch.nn.Module)
return list(module.named_parameters()) + list(module.named_buffers())


@torch.no_grad()
def copy_params_and_buffers(src_module, dst_module, require_all=False):
"""Copy parameters and buffers from a source module to target module"""
assert isinstance(src_module, torch.nn.Module)
assert isinstance(dst_module, torch.nn.Module)
src_tensors = dict(named_params_and_buffers(src_module))
Expand All @@ -741,6 +762,10 @@ def copy_params_and_buffers(src_module, dst_module, require_all=False):

@contextlib.contextmanager
def ddp_sync(module, sync):
"""
Context manager for easily enabling/disabling DistributedDataParallel
synchronization.
"""
assert isinstance(module, torch.nn.Module)
if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
yield
Expand All @@ -754,6 +779,7 @@ def ddp_sync(module, sync):


def check_ddp_consistency(module, ignore_regex=None):
"""Check DistributedDataParallel consistency across processes."""
assert isinstance(module, torch.nn.Module)
for name, tensor in named_params_and_buffers(module):
fullname = type(module).__name__ + "." + name
Expand All @@ -772,6 +798,7 @@ def check_ddp_consistency(module, ignore_regex=None):


def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
"""Print summary table of module hierarchy."""
assert isinstance(module, torch.nn.Module)
assert not isinstance(module, torch.jit.ScriptModule)
assert isinstance(inputs, (tuple, list))
Expand Down
4 changes: 4 additions & 0 deletions examples/weather/dataset_download/era5_mirror.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(self, base_path: str, fs: fsspec.AbstractFileSystem = None):
self.metadata = self.get_metadata()

def get_metadata(self):
"""Get metadata"""
if self.fs.exists(self.metadata_file):
with self.fs.open(self.metadata_file, "r") as f:
try:
Expand All @@ -69,10 +70,12 @@ def get_metadata(self):
return metadata

def save_metadata(self):
"""Save metadata"""
with self.fs.open(self.metadata_file, "w") as f:
json.dump(self.metadata, f)

def chunk_exists(self, variable, year, month, hours, pressure_level):
"""Check if chunk exists"""
for chunk in self.metadata["chunks"]:
if (
chunk["variable"] == variable
Expand Down Expand Up @@ -155,6 +158,7 @@ def download_chunk(
return ds

def variable_to_zarr_name(self, variable: str, pressure_level: int = None):
"""convert variable to zarr name"""
# create zarr path for variable
zarr_path = f"{self.base_path}/{variable}"
if pressure_level:
Expand Down
2 changes: 2 additions & 0 deletions examples/weather/graphcast/train_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@


class BaseTrainer:
"""Trainer class"""

def __init__(self):
pass

Expand Down
2 changes: 2 additions & 0 deletions examples/weather/graphcast/train_graphcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@


class GraphCastTrainer(BaseTrainer):
"""GraphCast Trainer"""

def __init__(self, wb, dist, rank_zero_logger):
super().__init__()
self.dist = dist
Expand Down
2 changes: 2 additions & 0 deletions examples/weather/graphcast/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@


class Validation:
"""Run validation on GraphCast model"""

def __init__(self, model, dtype, dist, wb):
self.model = model
self.dtype = dtype
Expand Down

0 comments on commit 23fb1db

Please sign in to comment.