diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 26d0ea10dd..2194cd4790 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,7 +29,7 @@ repos: "--ignore-magic", "--fail-under=99", "--exclude=['setup.py', 'test', 'build', 'docs']", "--ignore-regex=['forward', 'backward', 'reset_parameters', 'extra_repr', 'MetaData', 'apply_activation','exec_activation']", "--color", "--"] - exclude: ^modulus/internal/|^examples/ + exclude: ^modulus/internal/ - repo: https://github.com/igorshubovych/markdownlint-cli rev: v0.35.0 diff --git a/examples/cfd/ahmed_body_mgn/inference.py b/examples/cfd/ahmed_body_mgn/inference.py index 5dfc1f5f0f..a34dd564c7 100644 --- a/examples/cfd/ahmed_body_mgn/inference.py +++ b/examples/cfd/ahmed_body_mgn/inference.py @@ -102,6 +102,8 @@ def dgl_to_pyvista(graph: DGLGraph): class AhmedBodyRollout: + """MGN inference on Ahmed Body dataset""" + def __init__(self, wb, logger): # set device self.device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/examples/cfd/darcy_nested_fnos/evaluate_nested_darcy.py b/examples/cfd/darcy_nested_fnos/evaluate_nested_darcy.py index eabe6851b9..212abc6878 100644 --- a/examples/cfd/darcy_nested_fnos/evaluate_nested_darcy.py +++ b/examples/cfd/darcy_nested_fnos/evaluate_nested_darcy.py @@ -31,6 +31,7 @@ def plot_assembled(perm, darc): + """Utility for plotting""" headers = ["permeability", "darcy"] plt.rcParams.update({"font.size": 28}) fig, ax = plt.subplots(1, 2, figsize=(15 * 2, 15), sharey=True) @@ -52,6 +53,7 @@ def EvaluateModel( parent_result: FloatTensor = None, log: PythonLogger = None, ): + """Utility for running inference on trained model""" # define model and load weights dist = DistributedManager() log.info(f"evaluating model {model_name}") @@ -111,6 +113,7 @@ def forward_eval(invars): def AssembleSolutionToDict(cfg: DictConfig, perm: dict, darcy: dict, pos: dict): + """Assemble solution to easily interpretable dict""" dat, idx = {}, 0 for ii in 
range(perm["ref0"].shape[0]): samp = str(ii) @@ -142,6 +145,7 @@ def AssembleSolutionToDict(cfg: DictConfig, perm: dict, darcy: dict, pos: dict): def AssembleToSingleField(cfg: DictConfig, dat: dict): + """Assemble multiple fields to a single dict""" ref_fac = cfg.ref_fac glob_size = dat["0"]["ref0"]["0"]["darcy"].shape[0] inset_size = dat["0"]["ref1"]["0"]["darcy"].shape[0] @@ -175,6 +179,7 @@ def AssembleToSingleField(cfg: DictConfig, dat: dict): def GetRelativeL2(pred, tar): + """Compute L2 error""" div = 1.0 / tar["darcy"].shape[0] * tar["darcy"].shape[1] err = pred["darcy"] - tar["darcy"] @@ -185,6 +190,7 @@ def GetRelativeL2(pred, tar): def ComputeErrorNorm(cfg: DictConfig, pred_dict: dict, log: PythonLogger, ref0_pred): + """Compute relative L2-norm of error""" # assemble ref1 and ref2 solutions alongside gound truth to single scalar field log.info("computing relative L2-norm of error...") tar_dict = np.load(cfg.inference.inference_set, allow_pickle=True).item()["fields"] diff --git a/examples/cfd/gray_scott_rnn/gray_scott_rnn.py b/examples/cfd/gray_scott_rnn/gray_scott_rnn.py index 04cfd8da37..1b922f23f3 100644 --- a/examples/cfd/gray_scott_rnn/gray_scott_rnn.py +++ b/examples/cfd/gray_scott_rnn/gray_scott_rnn.py @@ -37,6 +37,7 @@ def prepare_data( predict_nr_tsteps, start_timestep, ): + """Data pre-processing""" if Path(output_data_path).is_file(): pass else: @@ -67,6 +68,7 @@ def prepare_data( def validation_step(model, dataloader, epoch): + """Validation Step""" model.eval() for data in dataloader: @@ -89,6 +91,8 @@ def validation_step(model, dataloader, epoch): class HDF5MapStyleDataset(Dataset): + """Simple map-style HDF5 dataset""" + def __init__( self, file_path, diff --git a/examples/cfd/navier_stokes_rnn/navier_stokes_rnn.py b/examples/cfd/navier_stokes_rnn/navier_stokes_rnn.py index 3045cda368..9b233ccf8b 100644 --- a/examples/cfd/navier_stokes_rnn/navier_stokes_rnn.py +++ b/examples/cfd/navier_stokes_rnn/navier_stokes_rnn.py @@ -40,6 +40,7 @@ def 
prepare_data( start_idx, num_samples, ): + """Data pre-processing""" if Path(output_data_path).is_file(): pass else: @@ -73,6 +74,7 @@ def prepare_data( def validation_step(model, dataloader, epoch): + """Validation step""" model.eval() loss_epoch = 0 @@ -99,6 +101,8 @@ def validation_step(model, dataloader, epoch): class HDF5MapStyleDataset(Dataset): + """Simple map-style HDF5 dataset""" + def __init__( self, file_path, diff --git a/examples/generative/diffusion/utils.py b/examples/generative/diffusion/utils.py index 2d0c00d502..5b0241c986 100644 --- a/examples/generative/diffusion/utils.py +++ b/examples/generative/diffusion/utils.py @@ -552,6 +552,7 @@ def open_url( def constant(value, shape=None, dtype=None, device=None, memory_format=None): + """Cached construction of constant tensors""" value = np.asarray(value) if shape is not None: shape = tuple(shape) @@ -591,6 +592,7 @@ def constant(value, shape=None, dtype=None, device=None, memory_format=None): def nan_to_num( input, nan=0.0, posinf=None, neginf=None, *, out=None ): # pylint: disable=redefined-builtin + """Replace NaN/Inf with specified numerical values""" assert isinstance(input, torch.Tensor) if posinf is None: posinf = torch.finfo(input.dtype).max @@ -617,6 +619,10 @@ def nan_to_num( @contextlib.contextmanager def suppress_tracer_warnings(): + """ + Context manager to temporarily suppress known warnings in torch.jit.trace(). + Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + """ flt = ("ignore", None, torch.jit.TracerWarning, None, 0) warnings.filters.insert(0, flt) yield @@ -630,6 +636,11 @@ def suppress_tracer_warnings(): def assert_shape(tensor, ref_shape): + """ + Assert that the shape of a tensor matches the given list of integers. + None indicates that the size of a dimension is allowed to vary. + Performs symbolic assertion when used in torch.jit.trace(). 
+ """ if tensor.ndim != len(ref_shape): raise AssertionError( f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}" @@ -660,6 +671,8 @@ def assert_shape(tensor, ref_shape): def profiled_function(fn): + """Function decorator that calls torch.autograd.profiler.record_function().""" + def decorator(*args, **kwargs): with torch.autograd.profiler.record_function(fn.__name__): return fn(*args, **kwargs) @@ -674,6 +687,11 @@ def decorator(*args, **kwargs): class InfiniteSampler(torch.utils.data.Sampler): + """ + Sampler for torch.utils.data.DataLoader that loops over the dataset + indefinitely, shuffling items as it goes. + """ + def __init__( self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5 ): @@ -714,17 +732,20 @@ def __iter__(self): def params_and_buffers(module): + """Get parameters and buffers of a nn.Module""" assert isinstance(module, torch.nn.Module) return list(module.parameters()) + list(module.buffers()) def named_params_and_buffers(module): + """Get named parameters and buffers of a nn.Module""" assert isinstance(module, torch.nn.Module) return list(module.named_parameters()) + list(module.named_buffers()) @torch.no_grad() def copy_params_and_buffers(src_module, dst_module, require_all=False): + """Copy parameters and buffers from a source module to target module""" assert isinstance(src_module, torch.nn.Module) assert isinstance(dst_module, torch.nn.Module) src_tensors = dict(named_params_and_buffers(src_module)) @@ -741,6 +762,10 @@ def copy_params_and_buffers(src_module, dst_module, require_all=False): @contextlib.contextmanager def ddp_sync(module, sync): + """ + Context manager for easily enabling/disabling DistributedDataParallel + synchronization. 
+ """ assert isinstance(module, torch.nn.Module) if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): yield @@ -754,6 +779,7 @@ def ddp_sync(module, sync): def check_ddp_consistency(module, ignore_regex=None): + """Check DistributedDataParallel consistency across processes.""" assert isinstance(module, torch.nn.Module) for name, tensor in named_params_and_buffers(module): fullname = type(module).__name__ + "." + name @@ -772,6 +798,7 @@ def check_ddp_consistency(module, ignore_regex=None): def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + """Print summary table of module hierarchy.""" assert isinstance(module, torch.nn.Module) assert not isinstance(module, torch.jit.ScriptModule) assert isinstance(inputs, (tuple, list)) diff --git a/examples/weather/dataset_download/era5_mirror.py b/examples/weather/dataset_download/era5_mirror.py index 21da16dca2..dbcc04988a 100644 --- a/examples/weather/dataset_download/era5_mirror.py +++ b/examples/weather/dataset_download/era5_mirror.py @@ -58,6 +58,7 @@ def __init__(self, base_path: str, fs: fsspec.AbstractFileSystem = None): self.metadata = self.get_metadata() def get_metadata(self): + """Get metadata""" if self.fs.exists(self.metadata_file): with self.fs.open(self.metadata_file, "r") as f: try: @@ -69,10 +70,12 @@ def get_metadata(self): return metadata def save_metadata(self): + """Save metadata""" with self.fs.open(self.metadata_file, "w") as f: json.dump(self.metadata, f) def chunk_exists(self, variable, year, month, hours, pressure_level): + """Check if chunk exists""" for chunk in self.metadata["chunks"]: if ( chunk["variable"] == variable @@ -155,6 +158,7 @@ def download_chunk( return ds def variable_to_zarr_name(self, variable: str, pressure_level: int = None): + """convert variable to zarr name""" # create zarr path for variable zarr_path = f"{self.base_path}/{variable}" if pressure_level: diff --git a/examples/weather/graphcast/train_base.py 
b/examples/weather/graphcast/train_base.py index e54674314a..314c713a21 100644 --- a/examples/weather/graphcast/train_base.py +++ b/examples/weather/graphcast/train_base.py @@ -25,6 +25,8 @@ class BaseTrainer: + """Trainer class""" + def __init__(self): pass diff --git a/examples/weather/graphcast/train_graphcast.py b/examples/weather/graphcast/train_graphcast.py index 2c0e28dd40..0fa4f9080a 100644 --- a/examples/weather/graphcast/train_graphcast.py +++ b/examples/weather/graphcast/train_graphcast.py @@ -63,6 +63,8 @@ class GraphCastTrainer(BaseTrainer): + """GraphCast Trainer""" + def __init__(self, wb, dist, rank_zero_logger): super().__init__() self.dist = dist diff --git a/examples/weather/graphcast/validation.py b/examples/weather/graphcast/validation.py index 5eca7b8193..b11605c054 100644 --- a/examples/weather/graphcast/validation.py +++ b/examples/weather/graphcast/validation.py @@ -24,6 +24,8 @@ class Validation: + """Run validation on GraphCast model""" + def __init__(self, model, dtype, dist, wb): self.model = model self.dtype = dtype