From 297511559e17e877c04f583ab63170f2339493c5 Mon Sep 17 00:00:00 2001 From: ivanauyeung <106317256+ivanauyeung@users.noreply.github.com> Date: Wed, 4 Dec 2024 04:39:58 +0800 Subject: [PATCH] DlWP HEALpix unifying original and coupled model config (#678) * Merge dlwp_healpix and dlwp_healpix_coupled Signed-off-by: root --------- Signed-off-by: root Co-authored-by: David Pruitt --- examples/weather/dlwp_healpix/README.md | 29 +- .../config_hpx32_coupled_dlom.yaml} | 38 +- .../config_hpx32_coupled_dlwp.yaml | 0 .../configs/callbacks/early_stopping.yaml | 0 .../configs/config_hpx32_coupled_dlom.yaml | 0 .../configs/config_hpx32_coupled_dlwp.yaml} | 6 +- .../data/era5_hpx32_8var-coupled_6h_24h.yaml | 0 .../era5_hpx32_dlom_sst-z1000-ws_48H-dt.yaml | 0 .../data/module/atmos_ConstantCoupling.yaml | 0 .../configs/data/module/sst-z1000-ws.yaml | 0 .../configs/data/scaling/hpx32.yaml | 15 + .../configs/data/scaling/hpx64.yaml | 0 .../configs/data/splits/large_test.yaml | 0 .../data/splits/large_test_1950-2022.yaml | 0 .../configs/model/coupled_hpx_rec_unet.yaml | 0 .../configs/model/coupled_hpx_unet_dlom.yaml} | 6 +- .../blocks/symmetric_conv_next_block.yaml | 0 .../decoder_symmetric-conv_90-90-180.yaml | 0 .../model/modules/decoder/unet_dec.yaml | 0 .../encoder_symmetric-conv_180-90-90.yaml | 0 .../model/modules/encoder/unet_enc.yaml | 0 .../criterion/hpx32_coupled-atmos.yaml | 0 .../configs/trainer/criterion/hpx64_7var.yaml | 0 .../configs/trainer/criterion/ocean_mse.yaml | 2 +- .../configs/trainer/dlom.yaml | 0 .../configs/trainer/dlwp.yaml | 0 .../trainer/lr_scheduler/constant.yaml | 0 .../configs/trainer/lr_scheduler/plateau.yaml | 0 .../weather/dlwp_healpix_coupled/README.md | 29 - .../callbacks/learning_rate_monitor.yaml | 19 - .../configs/callbacks/model_checkpoint.yaml | 23 - .../configs/callbacks/swa.yaml | 19 - .../configs/data/era5_hpx32_7var_6h_24h.yaml | 48 -- .../configs/data/era5_hpx64_7var_6h_24h.yaml | 48 -- .../configs/data/module/time_series.yaml | 42 -- .../configs/data/scaling/classic.yaml | 41 -- .../configs/data/scaling/hpx32.yaml | 66 --- .../configs/data/scaling/zeros.yaml | 41 -- .../configs/data/splits/1959-1998.yaml | 22 - .../configs/data/splits/1964-2003.yaml | 22 - .../configs/data/splits/default.yaml | 22 - .../modules/activations/capped_gelu.yaml | 18 - .../activations/capped_leaky_relu.yaml | 18 - .../model/modules/blocks/avg_pool.yaml | 18 - .../modules/blocks/basic_conv_block.yaml | 26 - .../model/modules/blocks/conv_gru_block.yaml | 20 - .../model/modules/blocks/conv_next_block.yaml | 26 - .../model/modules/blocks/output_layer.yaml | 25 - .../blocks/transposed_conv_upsample.yaml | 23 - .../model/modules/decoder/rec_unet_dec.yaml | 32 - .../model/modules/encoder/rec_unet_enc.yaml | 31 - .../configs/model/modules/loss/mse.yaml | 17 - .../configs/model/modules/loss/mse_ssim.yaml | 30 - .../configs/trainer/criterion/mse.yaml | 17 - .../configs/trainer/criterion/ocean_mse.yaml | 20 - .../trainer/criterion/weighted_mse.yaml | 26 - .../configs/trainer/default.yaml | 28 - .../configs/trainer/lr_scheduler/cosine.yaml | 22 - .../configs/trainer/optimizer/adam.yaml | 18 - .../weather/dlwp_healpix_coupled/train.py | 175 ------ .../weather/dlwp_healpix_coupled/trainer.py | 560 ------------------ .../weather/dlwp_healpix_coupled/utils.py | 125 ---- 62 files changed, 71 insertions(+), 1742 deletions(-) rename examples/weather/{dlwp_healpix_coupled/configs/model/hpx_rec_unet.yaml => dlwp_healpix/config_hpx32_coupled_dlom.yaml} (58%) rename examples/weather/{dlwp_healpix_coupled/configs => dlwp_healpix}/config_hpx32_coupled_dlwp.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/callbacks/early_stopping.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/config_hpx32_coupled_dlom.yaml (100%) rename examples/weather/{dlwp_healpix_coupled/configs/config.yaml => dlwp_healpix/configs/config_hpx32_coupled_dlwp.yaml} (92%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/data/era5_hpx32_8var-coupled_6h_24h.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/data/era5_hpx32_dlom_sst-z1000-ws_48H-dt.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/data/module/atmos_ConstantCoupling.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/data/module/sst-z1000-ws.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/data/scaling/hpx64.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/data/splits/large_test.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/data/splits/large_test_1950-2022.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/model/coupled_hpx_rec_unet.yaml (100%) rename examples/weather/{dlwp_healpix_coupled/configs/model/coupled_hpx_rec_unet_dlom.yaml => dlwp_healpix/configs/model/coupled_hpx_unet_dlom.yaml} (88%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/model/modules/blocks/symmetric_conv_next_block.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/model/modules/decoder/decoder_symmetric-conv_90-90-180.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/model/modules/decoder/unet_dec.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/model/modules/encoder/encoder_symmetric-conv_180-90-90.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/model/modules/encoder/unet_enc.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/trainer/criterion/hpx32_coupled-atmos.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/trainer/criterion/hpx64_7var.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/trainer/dlom.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/trainer/dlwp.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/trainer/lr_scheduler/constant.yaml (100%) rename examples/weather/{dlwp_healpix_coupled => dlwp_healpix}/configs/trainer/lr_scheduler/plateau.yaml (100%) delete mode 100644 examples/weather/dlwp_healpix_coupled/README.md delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/callbacks/learning_rate_monitor.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/callbacks/model_checkpoint.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/callbacks/swa.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/data/era5_hpx32_7var_6h_24h.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/data/era5_hpx64_7var_6h_24h.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/data/module/time_series.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/data/scaling/classic.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/data/scaling/hpx32.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/data/scaling/zeros.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/data/splits/1959-1998.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/data/splits/1964-2003.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/data/splits/default.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/model/modules/activations/capped_gelu.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/model/modules/activations/capped_leaky_relu.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/avg_pool.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/basic_conv_block.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/conv_gru_block.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/conv_next_block.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/output_layer.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/transposed_conv_upsample.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/model/modules/decoder/rec_unet_dec.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/model/modules/encoder/rec_unet_enc.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/model/modules/loss/mse.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/model/modules/loss/mse_ssim.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/mse.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/ocean_mse.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/weighted_mse.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/trainer/default.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/trainer/lr_scheduler/cosine.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/configs/trainer/optimizer/adam.yaml delete mode 100644 examples/weather/dlwp_healpix_coupled/train.py delete mode 100644 examples/weather/dlwp_healpix_coupled/trainer.py delete mode 100644 examples/weather/dlwp_healpix_coupled/utils.py diff --git a/examples/weather/dlwp_healpix/README.md b/examples/weather/dlwp_healpix/README.md index dbbf64ed90..f773ef5d39 100644 --- a/examples/weather/dlwp_healpix/README.md +++ b/examples/weather/dlwp_healpix/README.md @@ -5,7 +5,8 @@ This example is an implementation of the model. The DLWP model can be used to predict the state of the atmosphere given a previous atmospheric state. You can infer a 320-member ensemble set of six-week forecasts at 1.4° resolution within a couple of minutes, demonstrating the potential of AI in developing -near real-time digital twins for weather prediction +near real-time digital twins for weather prediction. This example also contains an +implementation of the coupled Ocean-Atmosphere DLWP model. ## Problem overview @@ -13,4 +14,28 @@ The goal is to train an AI model that can emulate the state of the atmosphere an global weather over a certain time span. The Deep Learning Weather Prediction (DLWP) model uses deep CNNs for globally gridded weather prediction. DLWP CNNs directly map u(t) to its future state u(t+Δt) by learning from historical observations of the weather, -with Δt set to 6 hr +with Δt set to 6 hr. The Deep Learning Ocean Model (DLOM) that is designed to couple with +deep learning weather prediction (DLWP) model. The DLOM forecasts sea surface +temperature (SST). DLOMs use deep learning techniques as in DLWP models but are +configured with different architectures and slower time stepping. DLOMs and DLWP models +are trained to learn atmosphere-ocean coupling. + +## Getting Started + +To train the DLWP HEALPix model, run + +```bash +python train.py +``` + +To train the coupled DLWP model, run + +```bash +python train.py --config-name config_hpx32_coupled_dlwp +``` + +To train the coupled DLOM model, run + +```bash +python train.py --config-name config_hpx32_coupled_dlom +``` diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/hpx_rec_unet.yaml b/examples/weather/dlwp_healpix/config_hpx32_coupled_dlom.yaml similarity index 58% rename from examples/weather/dlwp_healpix_coupled/configs/model/hpx_rec_unet.yaml rename to examples/weather/dlwp_healpix/config_hpx32_coupled_dlom.yaml index 1bebf83a47..de324a48f7 100644 --- a/examples/weather/dlwp_healpix_coupled/configs/model/hpx_rec_unet.yaml +++ b/examples/weather/dlwp_healpix/config_hpx32_coupled_dlom.yaml @@ -15,22 +15,28 @@ # limitations under the License. defaults: - - modules/encoder@encoder: rec_unet_enc - - modules/decoder@decoder: rec_unet_dec + - _self_ + - data: era5_hpx32_dlom_sst-z1000-ws_48H-dt + - model: coupled_hpx_unet_dlom + - trainer: dlom -_target_: modulus.models.dlwp_healpix.HEALPixRecUNet -_recursive_: false -presteps: 1 -input_time_dim: ${data.input_time_dim} -output_time_dim: ${data.output_time_dim} -delta_time: ${data.time_step} +experiment_name: ${now:%Y-%m-%d}/${now:%H-%M-%S} +output_dir: outputs/${experiment_name} +# checkpoints names are in the form training-state-.mdlus +checkpoint_name: last +load_weights_only: false +seed: 0 -# Parameters automatically overridden in train code -input_channels: 7 -output_channels: 7 -n_constants: 2 -decoder_input_channels: 1 +# Training specifications +batch_size: 64 +learning_rate: 1e-4 +num_workers: 8 -# some perf parameters -enable_nhwc: false -enable_healpixpad: false +# Distributed setup (multi GPU) +port: 29450 +master_address: localhost + +hydra: + verbose: true + run: + dir: ${output_dir} diff --git a/examples/weather/dlwp_healpix_coupled/configs/config_hpx32_coupled_dlwp.yaml b/examples/weather/dlwp_healpix/config_hpx32_coupled_dlwp.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/config_hpx32_coupled_dlwp.yaml rename to examples/weather/dlwp_healpix/config_hpx32_coupled_dlwp.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/callbacks/early_stopping.yaml b/examples/weather/dlwp_healpix/configs/callbacks/early_stopping.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/callbacks/early_stopping.yaml rename to examples/weather/dlwp_healpix/configs/callbacks/early_stopping.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/config_hpx32_coupled_dlom.yaml b/examples/weather/dlwp_healpix/configs/config_hpx32_coupled_dlom.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/config_hpx32_coupled_dlom.yaml rename to examples/weather/dlwp_healpix/configs/config_hpx32_coupled_dlom.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/config.yaml b/examples/weather/dlwp_healpix/configs/config_hpx32_coupled_dlwp.yaml similarity index 92% rename from examples/weather/dlwp_healpix_coupled/configs/config.yaml rename to examples/weather/dlwp_healpix/configs/config_hpx32_coupled_dlwp.yaml index 6c9591c36d..3e2538f75d 100644 --- a/examples/weather/dlwp_healpix_coupled/configs/config.yaml +++ b/examples/weather/dlwp_healpix/configs/config_hpx32_coupled_dlwp.yaml @@ -16,9 +16,9 @@ defaults: - _self_ - - data: era5_hpx64_7var_6h_24h - - model: hpx_rec_unet - - trainer: default + - data: era5_hpx32_8var-coupled_6h_24h + - model: coupled_hpx_rec_unet + - trainer: dlwp experiment_name: ${now:%Y-%m-%d}/${now:%H-%M-%S} output_dir: outputs/${experiment_name} diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/era5_hpx32_8var-coupled_6h_24h.yaml b/examples/weather/dlwp_healpix/configs/data/era5_hpx32_8var-coupled_6h_24h.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/data/era5_hpx32_8var-coupled_6h_24h.yaml rename to examples/weather/dlwp_healpix/configs/data/era5_hpx32_8var-coupled_6h_24h.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/era5_hpx32_dlom_sst-z1000-ws_48H-dt.yaml b/examples/weather/dlwp_healpix/configs/data/era5_hpx32_dlom_sst-z1000-ws_48H-dt.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/data/era5_hpx32_dlom_sst-z1000-ws_48H-dt.yaml rename to examples/weather/dlwp_healpix/configs/data/era5_hpx32_dlom_sst-z1000-ws_48H-dt.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/module/atmos_ConstantCoupling.yaml b/examples/weather/dlwp_healpix/configs/data/module/atmos_ConstantCoupling.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/data/module/atmos_ConstantCoupling.yaml rename to examples/weather/dlwp_healpix/configs/data/module/atmos_ConstantCoupling.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/module/sst-z1000-ws.yaml b/examples/weather/dlwp_healpix/configs/data/module/sst-z1000-ws.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/data/module/sst-z1000-ws.yaml rename to examples/weather/dlwp_healpix/configs/data/module/sst-z1000-ws.yaml diff --git a/examples/weather/dlwp_healpix/configs/data/scaling/hpx32.yaml b/examples/weather/dlwp_healpix/configs/data/scaling/hpx32.yaml index 6c4b37475a..65b7fd2d98 100644 --- a/examples/weather/dlwp_healpix/configs/data/scaling/hpx32.yaml +++ b/examples/weather/dlwp_healpix/configs/data/scaling/hpx32.yaml @@ -49,4 +49,19 @@ sst: tp6: mean: 0. std: 1. +z1000-24H: + mean: 936.7376098632812 + std: 883.0859375 +ws10-24H: + mean: 6.15248966217041 + std: 3.321399688720703 +ws10: + mean: 6.1497307 + std: 3.583117 +ws10-48H: + mean: 6.081215 + std: 3.1224248 +adt: + mean: 0.40309871564035454 + std: 0.6165622319307328 diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/scaling/hpx64.yaml b/examples/weather/dlwp_healpix/configs/data/scaling/hpx64.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/data/scaling/hpx64.yaml rename to examples/weather/dlwp_healpix/configs/data/scaling/hpx64.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/splits/large_test.yaml b/examples/weather/dlwp_healpix/configs/data/splits/large_test.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/data/splits/large_test.yaml rename to examples/weather/dlwp_healpix/configs/data/splits/large_test.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/splits/large_test_1950-2022.yaml b/examples/weather/dlwp_healpix/configs/data/splits/large_test_1950-2022.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/data/splits/large_test_1950-2022.yaml rename to examples/weather/dlwp_healpix/configs/data/splits/large_test_1950-2022.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/coupled_hpx_rec_unet.yaml b/examples/weather/dlwp_healpix/configs/model/coupled_hpx_rec_unet.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/model/coupled_hpx_rec_unet.yaml rename to examples/weather/dlwp_healpix/configs/model/coupled_hpx_rec_unet.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/coupled_hpx_rec_unet_dlom.yaml b/examples/weather/dlwp_healpix/configs/model/coupled_hpx_unet_dlom.yaml similarity index 88% rename from examples/weather/dlwp_healpix_coupled/configs/model/coupled_hpx_rec_unet_dlom.yaml rename to examples/weather/dlwp_healpix/configs/model/coupled_hpx_unet_dlom.yaml index 133baaf6da..0fd797b9cd 100644 --- a/examples/weather/dlwp_healpix_coupled/configs/model/coupled_hpx_rec_unet_dlom.yaml +++ b/examples/weather/dlwp_healpix/configs/model/coupled_hpx_unet_dlom.yaml @@ -15,10 +15,10 @@ # limitations under the License. defaults: - - modules/encoder@encoder: rec_unet_enc - - modules/decoder@decoder: rec_unet_dec + - modules/encoder@encoder: unet_enc + - modules/decoder@decoder: unet_dec -_target_: modulus.models.dlwp_healpix.HEALPixRecUNet +_target_: modulus.models.dlwp_healpix.HEALPixUNet _recursive_: false presteps: 0 input_time_dim: ${data.input_time_dim} diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/symmetric_conv_next_block.yaml b/examples/weather/dlwp_healpix/configs/model/modules/blocks/symmetric_conv_next_block.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/symmetric_conv_next_block.yaml rename to examples/weather/dlwp_healpix/configs/model/modules/blocks/symmetric_conv_next_block.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/decoder/decoder_symmetric-conv_90-90-180.yaml b/examples/weather/dlwp_healpix/configs/model/modules/decoder/decoder_symmetric-conv_90-90-180.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/model/modules/decoder/decoder_symmetric-conv_90-90-180.yaml rename to examples/weather/dlwp_healpix/configs/model/modules/decoder/decoder_symmetric-conv_90-90-180.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/decoder/unet_dec.yaml b/examples/weather/dlwp_healpix/configs/model/modules/decoder/unet_dec.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/model/modules/decoder/unet_dec.yaml rename to examples/weather/dlwp_healpix/configs/model/modules/decoder/unet_dec.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/encoder/encoder_symmetric-conv_180-90-90.yaml b/examples/weather/dlwp_healpix/configs/model/modules/encoder/encoder_symmetric-conv_180-90-90.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/model/modules/encoder/encoder_symmetric-conv_180-90-90.yaml rename to examples/weather/dlwp_healpix/configs/model/modules/encoder/encoder_symmetric-conv_180-90-90.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/encoder/unet_enc.yaml b/examples/weather/dlwp_healpix/configs/model/modules/encoder/unet_enc.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/model/modules/encoder/unet_enc.yaml rename to examples/weather/dlwp_healpix/configs/model/modules/encoder/unet_enc.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/hpx32_coupled-atmos.yaml b/examples/weather/dlwp_healpix/configs/trainer/criterion/hpx32_coupled-atmos.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/hpx32_coupled-atmos.yaml rename to examples/weather/dlwp_healpix/configs/trainer/criterion/hpx32_coupled-atmos.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/hpx64_7var.yaml b/examples/weather/dlwp_healpix/configs/trainer/criterion/hpx64_7var.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/hpx64_7var.yaml rename to examples/weather/dlwp_healpix/configs/trainer/criterion/hpx64_7var.yaml diff --git a/examples/weather/dlwp_healpix/configs/trainer/criterion/ocean_mse.yaml b/examples/weather/dlwp_healpix/configs/trainer/criterion/ocean_mse.yaml index ab4cd1a1bf..eb5a908ee9 100644 --- a/examples/weather/dlwp_healpix/configs/trainer/criterion/ocean_mse.yaml +++ b/examples/weather/dlwp_healpix/configs/trainer/criterion/ocean_mse.yaml @@ -16,5 +16,5 @@ _target_: modulus.metrics.climate.healpix_loss.OceanMSE -lsm_file: /invariants/hpx32_1950-2022_3h_sst-only.zarr +lsm_file: /datasets/healpix/HPX32/hpx32_1950-2022_3h_sst_coupled.zarr diff --git a/examples/weather/dlwp_healpix_coupled/configs/trainer/dlom.yaml b/examples/weather/dlwp_healpix/configs/trainer/dlom.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/trainer/dlom.yaml rename to examples/weather/dlwp_healpix/configs/trainer/dlom.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/trainer/dlwp.yaml b/examples/weather/dlwp_healpix/configs/trainer/dlwp.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/trainer/dlwp.yaml rename to examples/weather/dlwp_healpix/configs/trainer/dlwp.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/trainer/lr_scheduler/constant.yaml b/examples/weather/dlwp_healpix/configs/trainer/lr_scheduler/constant.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/trainer/lr_scheduler/constant.yaml rename to examples/weather/dlwp_healpix/configs/trainer/lr_scheduler/constant.yaml diff --git a/examples/weather/dlwp_healpix_coupled/configs/trainer/lr_scheduler/plateau.yaml b/examples/weather/dlwp_healpix/configs/trainer/lr_scheduler/plateau.yaml similarity index 100% rename from examples/weather/dlwp_healpix_coupled/configs/trainer/lr_scheduler/plateau.yaml rename to examples/weather/dlwp_healpix/configs/trainer/lr_scheduler/plateau.yaml diff --git a/examples/weather/dlwp_healpix_coupled/README.md b/examples/weather/dlwp_healpix_coupled/README.md deleted file mode 100644 index 61f188bf20..0000000000 --- a/examples/weather/dlwp_healpix_coupled/README.md +++ /dev/null @@ -1,29 +0,0 @@ -# Deep Learning Weather Prediction (DLWP) model for weather forecasting - -This example is an implementation of the coupled Ocean-Atmosphere DLWP model. - -## Problem overview - -The goal is to train an AI model that can emulate the state of the atmosphere and predict -global weather over a certain time span. The Deep Learning Weather Prediction (DLWP) model -uses deep CNNs for globally gridded weather prediction. DLWP CNNs directly map u(t) to -its future state u(t+Δt) by learning from historical observations of the weather, -with Δt set to 6 hr. The Deep Learning Ocean Model (DLOM) that is designed to couple with -deep learning weather prediction (DLWP) model. The DLOM forecasts sea surface -temperature (SST). DLOMs use deep learning techniques as in DLWP models but are -configured with different architectures and slower time stepping. DLOMs and DLWP models -are trained to learn atmosphere-ocean coupling. - -## Getting Started - -To train the coupled DLWP model, run - -```bash -python train.py --config-name config_hpx32_coupled_dlwp -``` - -To train the coupled DLOM model, run - -```bash -python train.py --config-name config_hpx32_coupled_dlom -``` diff --git a/examples/weather/dlwp_healpix_coupled/configs/callbacks/learning_rate_monitor.yaml b/examples/weather/dlwp_healpix_coupled/configs/callbacks/learning_rate_monitor.yaml deleted file mode 100644 index 5133d4b0df..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/callbacks/learning_rate_monitor.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -learning_rate_monitor: - _target_: pytorch_lightning.callbacks.LearningRateMonitor - logging_interval: epoch \ No newline at end of file diff --git a/examples/weather/dlwp_healpix_coupled/configs/callbacks/model_checkpoint.yaml b/examples/weather/dlwp_healpix_coupled/configs/callbacks/model_checkpoint.yaml deleted file mode 100644 index cff4cce20f..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/callbacks/model_checkpoint.yaml +++ /dev/null @@ -1,23 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -model_checkpoint: - _target_: pytorch_lightning.callbacks.ModelCheckpoint - filename: '{epoch:03d}-{val_loss:.4E}' - monitor: 'val_loss' - mode: 'min' - save_top_k: 10 - save_last: True \ No newline at end of file diff --git a/examples/weather/dlwp_healpix_coupled/configs/callbacks/swa.yaml b/examples/weather/dlwp_healpix_coupled/configs/callbacks/swa.yaml deleted file mode 100644 index 9b13af08be..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/callbacks/swa.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -stochastic_weight_avg: - _target_: pytorch_lightning.callbacks.StochasticWeightAveraging - swa_epoch_start: 5 \ No newline at end of file diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/era5_hpx32_7var_6h_24h.yaml b/examples/weather/dlwp_healpix_coupled/configs/data/era5_hpx32_7var_6h_24h.yaml deleted file mode 100644 index 0cf59d59ab..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/data/era5_hpx32_7var_6h_24h.yaml +++ /dev/null @@ -1,48 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -defaults: - - module: time_series - - scaling: classic - - splits: default - -src_directory: /datasets/healpix/HPX32 -dst_directory: /datasets/healpix/HPX32 -dataset_name: era5_hpx32_7var_6h_24h -prefix: era5_1deg_3h_HPX32_1979-2021_ -suffix: '' -data_format: classic -input_variables: - - z500 - - tau300-700 - - z1000 - - t2m0 - - tcwv0 - - t850 - - z250 -output_variables: null -constants: - land_sea_mask: lsm - topography: z -input_time_dim: 2 -output_time_dim: 4 -data_time_step: 3h -time_step: 6h -gap: 6h -add_insolation: true -nside: 32 -cube_dim: ${data.nside} -prebuilt_dataset: true diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/era5_hpx64_7var_6h_24h.yaml b/examples/weather/dlwp_healpix_coupled/configs/data/era5_hpx64_7var_6h_24h.yaml deleted file mode 100644 index dd6f4b133a..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/data/era5_hpx64_7var_6h_24h.yaml +++ /dev/null @@ -1,48 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -defaults: - - module: time_series - - scaling: classic - - splits: default - -src_directory: /datasets/healpix/HPX64 -dst_directory: /datasets/healpix/HPX64 -dataset_name: era5_hpx64_7var_6h_24h -prefix: era5_0.25deg_3h_HPX64_1979-2021_ -suffix: '' -data_format: classic -input_variables: - - z500 - - tau300-700 - - z1000 - - t2m0 - - tcwv0 - - t850 - - z250 -output_variables: null -constants: - land_sea_mask: lsm - topography: z -input_time_dim: 2 -output_time_dim: 4 -data_time_step: 3h -time_step: 6h -gap: 6h -add_insolation: true -nside: 64 -cube_dim: ${data.nside} -prebuilt_dataset: true diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/module/time_series.yaml b/examples/weather/dlwp_healpix_coupled/configs/data/module/time_series.yaml deleted file mode 100644 index f2ee59b2a2..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/data/module/time_series.yaml +++ /dev/null @@ -1,42 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -_target_: modulus.datapipes.healpix.data_modules.TimeSeriesDataModule -src_directory: ${data.src_directory} -dst_directory: ${data.dst_directory} -dataset_name: ${data.dataset_name} -prefix: ${data.prefix} -suffix: ${data.suffix} -data_format: ${data.data_format} -batch_size: ${batch_size} -drop_last: true -input_variables: ${data.input_variables} -output_variables: ${data.output_variables} -constants: ${data.constants} -scaling: ${data.scaling} -splits: ${data.splits} -presteps: ${model.presteps} -input_time_dim: ${data.input_time_dim} -output_time_dim: ${data.output_time_dim} -data_time_step: ${data.data_time_step} -time_step: ${data.time_step} -gap: ${data.gap} -shuffle: true -add_insolation: ${data.add_insolation} -cube_dim: ${data.cube_dim} -num_workers: ${num_workers} -pin_memory: true -prebuilt_dataset: ${data.prebuilt_dataset} \ No newline at end of file diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/scaling/classic.yaml b/examples/weather/dlwp_healpix_coupled/configs/data/scaling/classic.yaml deleted file mode 100644 index 65b86cedd0..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/data/scaling/classic.yaml +++ /dev/null @@ -1,41 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -t2m0: - mean: 287.8665771484375 - std: 14.86227798461914 -t850: - mean: 281.2710266113281 - std: 12.04991626739502 -tau300-700: - mean: 61902.72265625 - std: 2559.8408203125 -tcwv0: - mean: 24.034976959228516 - std: 16.411935806274414 -z1000: - mean: 952.1435546875 - std: 895.7516479492188 -z250: - mean: 101186.28125 - std: 5551.77978515625 -z500: - mean: 55625.9609375 - std: 2681.712890625 -tp6: - mean: 0. - std: 1. - log_epsilon: 1e-6 \ No newline at end of file diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/scaling/hpx32.yaml b/examples/weather/dlwp_healpix_coupled/configs/data/scaling/hpx32.yaml deleted file mode 100644 index adc3b0e03c..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/data/scaling/hpx32.yaml +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -t2m0: - mean: 287.8665771484375 - std: 14.86227798461914 -t850: - mean: 281.2710266113281 - std: 12.04991626739502 -tau300-700: - mean: 61902.72265625 - std: 2559.8408203125 -tcwv0: - mean: 24.034976959228516 - std: 16.411935806274414 -z1000: - mean: 952.1435546875 - std: 895.7516479492188 -z1000-48H: - mean: 934.4945 - std: 842.1188 -z250: - mean: 101186.28125 - std: 5551.77978515625 -z500: - mean: 55625.9609375 - std: 2681.712890625 -# calculated with data from 1979-2018 -sst-ti: - mean: 290.53864 - std: 10.5835 -# calculated with data from 1950-2022 -sst: - mean: 290.64487 - std: 10.5792 -tp6: - mean: 0. - std: 1. -z1000-24H: - mean: 936.7376098632812 - std: 883.0859375 -ws10-24H: - mean: 6.15248966217041 - std: 3.321399688720703 -ws10: - mean: 6.1497307 - std: 3.583117 -ws10-48H: - mean: 6.081215 - std: 3.1224248 -adt: - mean: 0.40309871564035454 - std: 0.6165622319307328 diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/scaling/zeros.yaml b/examples/weather/dlwp_healpix_coupled/configs/data/scaling/zeros.yaml deleted file mode 100644 index 5b12f9935d..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/data/scaling/zeros.yaml +++ /dev/null @@ -1,41 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -z500: - mean: 0. - std: 1. -z1000: - mean: 0. - std: 1. -tau300-700: - mean: 0. - std: 1. -t2m0: - mean: 0. - std: 1. -tcwv0: - mean: 0. - std: 1. -t850: - mean: 0. - std: 1. -z250: - mean: 0. - std: 1. -tp6: - mean: 0. - std: 1. - log_epsilon: 1e-6 \ No newline at end of file diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/splits/1959-1998.yaml b/examples/weather/dlwp_healpix_coupled/configs/data/splits/1959-1998.yaml deleted file mode 100644 index 81c8129435..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/data/splits/1959-1998.yaml +++ /dev/null @@ -1,22 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -train_date_start: 1959-01-01 -train_date_end: 1998-12-31T18:00 -val_date_start: 1999-01-01 -val_date_end: 2000-12-31T18:00 -test_date_start: 2017-01-01 -test_date_end: 2018-12-31T18:00 \ No newline at end of file diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/splits/1964-2003.yaml b/examples/weather/dlwp_healpix_coupled/configs/data/splits/1964-2003.yaml deleted file mode 100644 index ad0e4aa0d4..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/data/splits/1964-2003.yaml +++ /dev/null @@ -1,22 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -train_date_start: 1964-01-01 -train_date_end: 2003-12-31T18:00 -val_date_start: 2004-01-01 -val_date_end: 2005-12-31T18:00 -test_date_start: 2017-01-01 -test_date_end: 2018-12-31T18:00 \ No newline at end of file diff --git a/examples/weather/dlwp_healpix_coupled/configs/data/splits/default.yaml b/examples/weather/dlwp_healpix_coupled/configs/data/splits/default.yaml deleted file mode 100644 index da8ad7eaa6..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/data/splits/default.yaml +++ /dev/null @@ -1,22 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -train_date_start: 1979-01-01 -train_date_end: 2012-12-31T18:00 -val_date_start: 2013-01-01 -val_date_end: 2016-12-31T18:00 -test_date_start: 2017-01-01 -test_date_end: 2018-12-31T18:00 \ No newline at end of file diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/activations/capped_gelu.yaml b/examples/weather/dlwp_healpix_coupled/configs/model/modules/activations/capped_gelu.yaml deleted file mode 100644 index c27f1f3953..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/model/modules/activations/capped_gelu.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -_target_: modulus.models.layers.activations.CappedGELU -cap_value: 10 \ No newline at end of file diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/activations/capped_leaky_relu.yaml b/examples/weather/dlwp_healpix_coupled/configs/model/modules/activations/capped_leaky_relu.yaml deleted file mode 100644 index 18020f4357..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/model/modules/activations/capped_leaky_relu.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -_target_: modulus.models.layers.activations.CappedLeakyReLU -cap_value: 10 \ No newline at end of file diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/avg_pool.yaml b/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/avg_pool.yaml deleted file mode 100644 index 40a07bc208..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/avg_pool.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -_target_: modulus.models.dlwp_healpix_layers.healpix_blocks.AvgPool -pooling: 2 diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/basic_conv_block.yaml b/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/basic_conv_block.yaml deleted file mode 100644 index 9e57d46877..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/basic_conv_block.yaml +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -defaults: - - /model/modules/activations@activation: capped_gelu - -_target_: modulus.models.dlwp_healpix_layers.healpix_blocks.BasicConvBlock -_recursive_: true -in_channels: 3 -out_channels: 1 -kernel_size: 3 -dilation: 1 -n_layers: 1 diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/conv_gru_block.yaml b/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/conv_gru_block.yaml deleted file mode 100644 index 64670db177..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/conv_gru_block.yaml +++ /dev/null @@ -1,20 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -_target_: modulus.models.dlwp_healpix_layers.healpix_blocks.ConvGRUBlock -_recursive_: false -in_channels: 3 -kernel_size: 1 diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/conv_next_block.yaml b/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/conv_next_block.yaml deleted file mode 100644 index 96eacea0df..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/conv_next_block.yaml +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -defaults: - - /model/modules/activations@activation: capped_gelu - -_target_: modulus.models.dlwp_healpix_layers.healpix_blocks.ConvNeXtBlock -_recursive_: true -in_channels: 3 -out_channels: 1 -kernel_size: 3 -dilation: 1 -upscale_factor: 4 diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/output_layer.yaml b/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/output_layer.yaml deleted file mode 100644 index 001d399204..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/output_layer.yaml +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -defaults: - - activation: null - -_target_: modulus.models.dlwp_healpix_layers.healpix_blocks.BasicConvBlock -in_channels: 3 -out_channels: 2 -kernel_size: 1 -dilation: 1 -n_layers: 1 diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/transposed_conv_upsample.yaml b/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/transposed_conv_upsample.yaml deleted file mode 100644 index 7cf1e7e330..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/model/modules/blocks/transposed_conv_upsample.yaml +++ /dev/null @@ -1,23 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -defaults: - - /model/modules/activations@activation: capped_gelu - -_target_: modulus.models.dlwp_healpix_layers.healpix_blocks.TransposedConvUpsample -in_channels: 3 -out_channels: 1 -upsampling: 2 diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/decoder/rec_unet_dec.yaml b/examples/weather/dlwp_healpix_coupled/configs/model/modules/decoder/rec_unet_dec.yaml deleted file mode 100644 index e816b7f9ca..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/model/modules/decoder/rec_unet_dec.yaml +++ /dev/null @@ -1,32 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -defaults: - - /model/modules/blocks@conv_block: conv_next_block - - /model/modules/blocks@up_sampling_block: transposed_conv_upsample - - /model/modules/blocks@recurrent_block: conv_gru_block - - /model/modules/blocks@output_layer: output_layer - -_target_: modulus.models.dlwp_healpix_layers.healpix_decoder.UNetDecoder -_recursive_: false -n_channels: - - 34 - - 68 - - 136 -dilations: - - 4 - - 2 - - 1 diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/encoder/rec_unet_enc.yaml b/examples/weather/dlwp_healpix_coupled/configs/model/modules/encoder/rec_unet_enc.yaml deleted file mode 100644 index e92a2f98ee..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/model/modules/encoder/rec_unet_enc.yaml +++ /dev/null @@ -1,31 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -defaults: - - /model/modules/blocks@conv_block: conv_next_block - - /model/modules/blocks@down_sampling_block: avg_pool - - /model/modules/blocks@recurrent_block: conv_gru_block - -_target_: modulus.models.dlwp_healpix_layers.healpix_encoder.UNetEncoder -_recursive_: false -n_channels: - - 136 - - 68 - - 34 -dilations: - - 1 - - 2 - - 4 \ No newline at end of file diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/loss/mse.yaml b/examples/weather/dlwp_healpix_coupled/configs/model/modules/loss/mse.yaml deleted file mode 100644 index cee82c7ca0..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/model/modules/loss/mse.yaml +++ /dev/null @@ -1,17 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -_target_: torch.nn.MSELoss \ No newline at end of file diff --git a/examples/weather/dlwp_healpix_coupled/configs/model/modules/loss/mse_ssim.yaml b/examples/weather/dlwp_healpix_coupled/configs/model/modules/loss/mse_ssim.yaml deleted file mode 100644 index 00f0d87f01..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/model/modules/loss/mse_ssim.yaml +++ /dev/null @@ -1,30 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -_target_: modulus.metrics.climate.loss.MSE_SSIM - -mse_params: -ssim_params: - window_size: 11 - time_series_forecasting: True -# variables over which to calculate SSIM-MSE weighted average loss -ssim_variables: - - ttr1h - - tcwv0 -# used to calculated the weighted average between MSE and DSSIM -weights: - - 0. - - 1. diff --git a/examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/mse.yaml b/examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/mse.yaml deleted file mode 100644 index a6b617f3b1..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/mse.yaml +++ /dev/null @@ -1,17 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -_target_: modulus.metrics.climate.healpix_loss.BaseMSE diff --git a/examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/ocean_mse.yaml b/examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/ocean_mse.yaml deleted file mode 100644 index eb5a908ee9..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/ocean_mse.yaml +++ /dev/null @@ -1,20 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -_target_: modulus.metrics.climate.healpix_loss.OceanMSE - -lsm_file: /datasets/healpix/HPX32/hpx32_1950-2022_3h_sst_coupled.zarr - diff --git a/examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/weighted_mse.yaml b/examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/weighted_mse.yaml deleted file mode 100644 index ff9df3a269..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/trainer/criterion/weighted_mse.yaml +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -_target_: modulus.metrics.climate.healpix_loss.WeightedMSE - -weights: - - 1.0 - - 1.0 - - 1.0 - - 1.0 - - 1.0 - - 1.0 - - 1.0 diff --git a/examples/weather/dlwp_healpix_coupled/configs/trainer/default.yaml b/examples/weather/dlwp_healpix_coupled/configs/trainer/default.yaml deleted file mode 100644 index 205501702f..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/trainer/default.yaml +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -defaults: - - criterion: mse - - optimizer: adam - - lr_scheduler: cosine - -_target_: trainer.Trainer -_recursive_: true -max_epochs: 300 -early_stopping_patience: null -amp_mode: "fp16" -graph_mode: "train_eval" -output_dir: ${output_dir} diff --git a/examples/weather/dlwp_healpix_coupled/configs/trainer/lr_scheduler/cosine.yaml b/examples/weather/dlwp_healpix_coupled/configs/trainer/lr_scheduler/cosine.yaml deleted file mode 100644 index 20d290951c..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/trainer/lr_scheduler/cosine.yaml +++ /dev/null @@ -1,22 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -_target_: torch.optim.lr_scheduler.CosineAnnealingLR -optimizer: ${model.optimizer} -T_max: ${trainer.max_epochs} -eta_min: 4e-5 -last_epoch: -1 -verbose: false \ No newline at end of file diff --git a/examples/weather/dlwp_healpix_coupled/configs/trainer/optimizer/adam.yaml b/examples/weather/dlwp_healpix_coupled/configs/trainer/optimizer/adam.yaml deleted file mode 100644 index fc09755d67..0000000000 --- a/examples/weather/dlwp_healpix_coupled/configs/trainer/optimizer/adam.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -_target_: torch.optim.Adam -lr: ${learning_rate} \ No newline at end of file diff --git a/examples/weather/dlwp_healpix_coupled/train.py b/examples/weather/dlwp_healpix_coupled/train.py deleted file mode 100644 index 28ab2445ee..0000000000 --- a/examples/weather/dlwp_healpix_coupled/train.py +++ /dev/null @@ -1,175 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -import sys - -import hydra -import numpy as np -import torch as th -from modulus.distributed import DistributedManager -from hydra.utils import instantiate - -from modulus import Module -from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper - -from pathlib import Path - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - - -@hydra.main(config_path="./configs", config_name="config", version_base=None) -def train(cfg): - """Train DLWP HEALPix weather model using the techniques described in the - paper "Advancing Parsimonious Deep Learning Weather Prediction using the HEALPix Mesh". - """ - # Initialize distributed - DistributedManager.initialize() - dist = DistributedManager() - - # set device globally to be sure that no spurious context are created on gpu 0: - th.cuda.set_device(dist.device) - - # Initialize logger. - os.makedirs(".logs", exist_ok=True) - logger = PythonLogger(name="train") # General python logger - logger0 = RankZeroLoggingWrapper(logger, dist) - logger.file_logging(file_name=f".logs/train_{dist.rank}.log") - logger0.info(f"experiment working directory: {os.getcwd()}") - - # Seed - if cfg.seed is not None: - th.manual_seed(cfg.seed) - if th.cuda.is_available(): - th.cuda.manual_seed(cfg.seed) - np.random.seed(cfg.seed) - - # Data module - data_module = instantiate(cfg.data.module) - - # Model - input_channels = len(cfg.data.input_variables) - output_channels = ( - len(cfg.data.output_variables) - if cfg.data.output_variables is not None - else input_channels - ) - constants_arr = data_module.constants - n_constants = ( - 0 if constants_arr is None else len(constants_arr.keys()) - ) # previously was 0 but with new format it is 1 - - decoder_input_channels = int(cfg.data.get("add_insolation", 0)) - cfg.model["input_channels"] = input_channels - cfg.model["output_channels"] = output_channels - cfg.model["n_constants"] = n_constants - cfg.model["decoder_input_channels"] = decoder_input_channels - - # convert Hydra cfg to pure dicts so they can be saved using modulus - model = instantiate(cfg.model, _convert_="all") - model.batch_size = cfg.batch_size - model.learning_rate = cfg.learning_rate - - # Instantiate PyTorch modules (with state dictionaries from checkpoint if given) - criterion = instantiate(cfg.trainer.criterion) - optimizer = instantiate(cfg.trainer.optimizer, params=model.parameters()) - lr_scheduler = ( - instantiate(cfg.trainer.lr_scheduler, optimizer=optimizer) - if cfg.trainer.lr_scheduler is not None - else None - ) - - # setup startup values - epoch = 1 - val_error = th.inf - iteration = 0 - epochs_since_improved = 0 - - # Prepare training under consideration of checkpoint if given - if cfg.get("checkpoint_name", None) is not None: - checkpoint_path = Path( - cfg.get("output_dir"), - "tensorboard", - "checkpoints", - "training-state-" + cfg.get("checkpoint_name") + ".mdlus", - ) - optimizer_path = Path( - cfg.get("output_dir"), - "tensorboard", - "checkpoints", - "optimizer-state-" + cfg.get("checkpoint_name") + ".ckpt", - ) - if checkpoint_path.exists(): - logger0.info(f"Loading checkpoint: {checkpoint_path}") - model = Module.from_checkpoint(str(checkpoint_path)) - checkpoint = th.load(optimizer_path, map_location=dist.device) - if not cfg.get("load_weights_only"): - # Load optimizer - optimizer = instantiate( - cfg.trainer.optimizer, params=model.parameters() - ) - optimizer_state_dict = checkpoint["optimizer_state_dict"] - optimizer.load_state_dict(optimizer_state_dict) - # Move tensors to the appropriate device as in https://github.com/pytorch/pytorch/issues/2830 - for state in optimizer.state.values(): - for k, v in state.items(): - if th.is_tensor(v): - state[k] = v.to(device=dist.device) - # Optionally load scheduler - if cfg.trainer.lr_scheduler is not None: - lr_scheduler = instantiate( - cfg.trainer.lr_scheduler, optimizer=optimizer - ) - lr_scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) - else: - lr_scheduler = None - epoch = checkpoint["epoch"] - val_error = checkpoint["val_error"] - iteration = checkpoint["iteration"] - epochs_since_improved = ( - checkpoint["epochs_since_improved"] - if "epochs_since_improved" in checkpoint.keys() - else 0 - ) - else: - logger0.info( - f"Checkpoint not found, weights not loaded. Requested path: {checkpoint_path}" - ) - - # Instantiate the trainer and fit the model - logger0.info("Model initialized") - trainer = instantiate( - cfg.trainer, - model=model, - data_module=data_module, - criterion=criterion, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - device=dist.device, - ) - logger0.info(f"starting training") - trainer.fit( - epoch=epoch, - validation_error=val_error, - iteration=iteration, - epochs_since_improved=epochs_since_improved, - ) - - -if __name__ == "__main__": - train() - print("Done.") diff --git a/examples/weather/dlwp_healpix_coupled/trainer.py b/examples/weather/dlwp_healpix_coupled/trainer.py deleted file mode 100644 index ad67a1b802..0000000000 --- a/examples/weather/dlwp_healpix_coupled/trainer.py +++ /dev/null @@ -1,560 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -#!/usr/bin/env python3 -import gc -import os -import threading - -import torch - -# distributed stuff -from modulus.distributed import DistributedManager -from torch.cuda import amp -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm - -# custom -from utils import write_checkpoint - -from modulus.launch.logging import PythonLogger, RankZeroLoggingWrapper - - -class Trainer: - """ - A class for DLWP model training - """ - - def __init__( - self, - model: torch.nn.Module, # Specify... (import) - data_module: torch.nn.Module, # Specify... (import) - criterion: torch.nn.Module, # Specify... (import) - optimizer: torch.nn.Module, # Specify... (import) - lr_scheduler: torch.nn.Module, # Specify... (import) - max_epochs: int = 500, - early_stopping_patience: int = None, - amp_mode: str = "none", - graph_mode: str = "none", - device: torch.device = torch.device("cpu"), - output_dir: str = "/outputs/", - max_norm: float = None, - ): - """ - Constructor. - - Parameters: - model: torch.nn.Module - The model to train - data_module: torch.nn.Module - The Pytorch module used for dataloading - criterion: torch.nn.Module - The PyTorch loss module to use - optimizer: torch.nn.Module - The PyTorch optimizer module to use - lr_scheduler: torch.nn.Module - The PyTorch learning rate scheduler module to use - max_epochs: int, optional - The maximum number of epochs to train for - early_stopping_patience: int, optional - amp_mode: str, optional - amp mode to use, valid options ["fp16", "bfloat16"] - graph_mode: str, optional - Where to use cudagraphs for training, valid options ["train", "train_eval"] - device: torch.device, optional - Device on which to run training on, can be any available torch.device - output_dir: str, optional - Where to store results - max_norm: float, optional - Maximum norm to use for training - """ - self.device = device - self.amp_enable = False if (amp_mode == "none") else True - self.amp_dtype = torch.float16 if (amp_mode == "fp16") else torch.bfloat16 - self.output_variables = data_module.output_variables - self.early_stopping_patience = early_stopping_patience - self.max_norm = max_norm - - self.model = model.to(device=self.device) - - self.dist = DistributedManager() - - # Initialize logger. - self.logger = PythonLogger(name="training_loop") # General python logger - self.logger0 = RankZeroLoggingWrapper(self.logger, self.dist) - self.logger.file_logging(file_name=f".logs/training_loop_{self.dist.rank}.log") - - self.dataloader_train, self.sampler_train = data_module.train_dataloader( - num_shards=self.dist.world_size, shard_id=self.dist.rank - ) - self.dataloader_valid, self.sampler_valid = data_module.val_dataloader( - num_shards=self.dist.world_size, shard_id=self.dist.rank - ) - self.output_dir_tb = os.path.join(output_dir, "tensorboard") - - # set the other parameters - self.optimizer = optimizer - # Set up criterion, pass metadata - self.criterion = criterion.to(device=self.device) - try: - self.criterion.setup(self) - except AttributeError: - raise NotImplementedError( - 'Attribute error encountered in call to criterio.setup(). \ - Could be that criterion is not compatable with custom loss dlwp training. See \ - "modulus/metrics/climate/healpix_loss.py" for proper criterion implementation examples.' - ) - - # opportunity for custom loss classes to get everything in order - self.lr_scheduler = lr_scheduler - self.max_epochs = max_epochs - - # add gradient scaler - self.gscaler = amp.GradScaler( - enabled=(self.amp_enable and self.amp_dtype == torch.float16) - ) - - # use distributed data parallel if requested: - self.print_to_screen = True - self.train_graph = None - self.eval_graph = None - - # for status bars - self.print_to_screen = self.dist.rank == 0 - - if self.dist.device.type == "cuda": - capture_stream = torch.cuda.Stream() - if torch.distributed.is_initialized(): - with torch.cuda.stream(capture_stream): - self.model = DDP( - self.model, - device_ids=[self.device.index], - output_device=[self.device.index], - broadcast_buffers=True, - find_unused_parameters=False, - gradient_as_bucket_view=True, - ) - capture_stream.synchronize() - - # capture graph if requested - if graph_mode in ["train", "train_eval"]: - self.logger0.info("Capturing model for training ...") - # get the shapes - inp, tar = next(iter(self.dataloader_train)) - - self._train_capture(capture_stream, [x.shape for x in inp], tar.shape) - - if graph_mode == "train_eval": - self.logger0.info("Capturing model for validation ...") - self._eval_capture(capture_stream) - - # Set up tensorboard summary_writer or try 'weights and biases' - # Initialize tensorbaord to track scalars - if self.dist.rank == 0: - self.writer = SummaryWriter(log_dir=self.output_dir_tb) - - def _train_capture( - self, capture_stream, inp_shapes, tar_shape, num_warmup_steps=20 - ): - # perform graph capture of the model - self.static_inp = [ - torch.zeros(x_shape, dtype=torch.float32, device=self.device) - for x_shape in inp_shapes - ] - self.static_tar = torch.zeros( - tar_shape, dtype=torch.float32, device=self.device - ) - - self.model.train() - capture_stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(capture_stream): - for _ in range(num_warmup_steps): - self.model.zero_grad(set_to_none=True) - - # FW - with amp.autocast(enabled=self.amp_enable, dtype=self.amp_dtype): - self.static_gen_train = self.model.forward(self.static_inp) - - self.static_loss_train = self.criterion( - self.static_gen_train, self.static_tar - ) - - # BW - self.gscaler.scale(self.static_loss_train).backward() - - # sync here - capture_stream.synchronize() - - gc.collect() - torch.cuda.empty_cache() - - # create graph - self.train_graph = torch.cuda.CUDAGraph() - - # zero grads before capture: - self.model.zero_grad(set_to_none=True) - - # start capture - with torch.cuda.graph(self.train_graph): - # FW - with amp.autocast(enabled=self.amp_enable, dtype=self.amp_dtype): - # self.static_gen_train = self.model(self.static_inp) - self.static_gen_train = self.model.forward(self.static_inp) - - self.static_loss_train = self.criterion( - self.static_gen_train, self.static_tar - ) - - # BW - self.gscaler.scale(self.static_loss_train).backward() - - # wait for capture to finish - torch.cuda.current_stream().wait_stream(capture_stream) - - def _eval_capture(self, capture_stream, num_warmup_steps=20): - self.model.eval() - capture_stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(capture_stream): - with torch.no_grad(): - for _ in range(num_warmup_steps): - # FW - with amp.autocast(enabled=self.amp_enable, dtype=self.amp_dtype): - # self.static_gen_eval = self.model(self.static_inp) - self.static_gen_eval = self.model.forward(self.static_inp) - - self.static_loss_eval = self.criterion( - self.static_gen_eval, self.static_tar - ) - # False flag for average channels ensures criterion will keep variable loss separated - self.static_losses_eval = self.criterion( - self.static_gen_eval, - self.static_tar, - average_channels=False, - ) - - # sync here - capture_stream.synchronize() - - gc.collect() - torch.cuda.empty_cache() - - # create graph - self.eval_graph = torch.cuda.CUDAGraph() - - # start capture: - with torch.cuda.graph(self.eval_graph, pool=self.train_graph.pool()): - # FW - with torch.no_grad(): - with amp.autocast(enabled=self.amp_enable, dtype=self.amp_dtype): - # self.static_gen_eval = self.model(self.static_inp) - self.static_gen_eval = self.model.forward(self.static_inp) - - self.static_loss_eval = self.criterion( - self.static_gen_eval, self.static_tar - ) - # False flag for average channels ensures criterion will keep variable loss separated - self.static_losses_eval = self.criterion( - self.static_gen_eval, - self.static_tar, - average_channels=False, - ) - - # wait for capture to finish - torch.cuda.current_stream().wait_stream(capture_stream) - - def fit( - self, - epoch: int = 0, - validation_error: torch.Tensor = torch.inf, - iteration: int = 0, - epochs_since_improved: int = 0, - ): - """ - Perform training by iterating over all epochs - - Parameters - epoch: int, optional - Current epoch number - validation_error: torch.Tensor, optional - Current best validation error - iteration: int, optional - Current iteration number - epochs_since_improved: int, optional - Number of epochs that have seen improvement in validation error - """ - best_validation_error = validation_error - for epoch in range(epoch, self.max_epochs): - torch.cuda.nvtx.range_push(f"training epoch {epoch}") - - if self.sampler_train is not None: - self.sampler_train.set_epoch(epoch) - - # Train: iterate over all training samples - training_step = 0 - self.model.train() - for inputs, target in ( - pbar := tqdm(self.dataloader_train, disable=(not self.print_to_screen)) - ): - pbar.set_description(f"Training epoch {epoch}/{self.max_epochs}") - - # Trach epoch in tensorboard - if self.dist.rank == 0: - self.writer.add_scalar( - tag="epoch", scalar_value=epoch, global_step=iteration - ) - - torch.cuda.nvtx.range_push(f"training step {training_step}") - - inputs = [x.to(device=self.device) for x in inputs] - target = target.to(device=self.device) - - # do optimizer step - if self.train_graph is not None: - # copy data into entry nodes - for idx, inp in enumerate(inputs): - self.static_inp[idx].copy_(inp) - - self.static_tar.copy_(target) - - # replay - self.train_graph.replay() - - # extract loss - output = self.static_gen_train - train_loss = self.static_loss_train - else: - # zero grads - self.model.zero_grad(set_to_none=True) - - if self.amp_enable: - with amp.autocast( - enabled=self.amp_enable, dtype=self.amp_dtype - ): - output = self.model(inputs) - train_loss = self.criterion(output, target) - else: - output = self.model(inputs) - train_loss = self.criterion(output, target) - - self.gscaler.scale(train_loss).backward() - - # Gradient clipping - self.gscaler.unscale_(self.optimizer) - try: - curr_lr = ( - self.optimizer.param_groups[-1]["lr"] - if self.lr_scheduler is None - else self.lr_scheduler.get_last_lr()[0] - ) - except ( - AttributeError - ): # try loop required since LearnOnPlateau has no "get_last_lr" attribute - curr_lr = ( - self.optimizer.param_groups[-1]["lr"] - if self.lr_scheduler is None - else self.optimizer.param_groups[0]["lr"] - ) - # check that max norm was not given to trainer - if self.max_norm is None: - torch.nn.utils.clip_grad_norm_(self.model.parameters(), curr_lr) - else: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.max_norm - ) - - # Optimizer step - self.gscaler.step(self.optimizer) - self.gscaler.update() - - pbar.set_postfix({"Loss": train_loss.item()}) - - torch.cuda.nvtx.range_pop() - - if self.dist.rank == 0: - self.writer.add_scalar( - tag="loss", scalar_value=train_loss, global_step=iteration - ) - iteration += 1 - training_step += 1 - - torch.cuda.nvtx.range_pop() - torch.cuda.nvtx.range_push(f"validation epoch {epoch}") - - # Validate (without gradients) - if self.sampler_valid is not None: - self.sampler_valid.set_epoch(epoch) - - self.model.eval() - with torch.no_grad(): - validation_stats = torch.zeros( - (2 + len(self.output_variables)), - dtype=torch.float32, - device=self.device, - ) - for inputs, target in ( - pbar := tqdm( - self.dataloader_valid, disable=(not self.print_to_screen) - ) - ): - pbar.set_description(f"Validation epoch {epoch}/{self.max_epochs}") - inputs = [x.to(device=self.device) for x in inputs] - target = target.to(device=self.device) - bsize = float(target.shape[0]) - - # do eval step - if self.eval_graph is not None: - # copy data into entry nodes - for idx, inp in enumerate(inputs): - self.static_inp[idx].copy_(inp) - self.static_tar.copy_(target) - - # replay graph - self.eval_graph.replay() - - # increase the loss - validation_stats[0] += self.static_loss_eval * bsize - - # Same for the per-variable loss - for v_idx, v_name in enumerate(self.output_variables): - validation_stats[1 + v_idx] += ( - self.static_losses_eval[v_idx] * bsize - ) - else: - if self.amp_enable: - with amp.autocast( - enabled=self.amp_enable, dtype=self.amp_dtype - ): - output = self.model(inputs) - validation_stats[0] += ( - self.criterion(prediction=output, target=target) - * bsize - ) - # save per variable loss - eval_losses = self.criterion( - output, target, average_channels=False - ) - for v_idx, v_name in enumerate(self.output_variables): - validation_stats[1 + v_idx] += ( - eval_losses[v_idx] * bsize - ) - else: - output = self.model(inputs) - validation_stats[0] += ( - self.criterion(prediction=output, target=target) * bsize - ) - eval_losses = self.criterion( - output, target, average_channels=False - ) - for v_idx, v_name in enumerate(self.output_variables): - validation_stats[1 + v_idx] += ( - eval_losses[v_idx] * bsize - ) - - pbar.set_postfix( - {"Loss": (validation_stats[0] / validation_stats[-1]).item()} - ) - - # increment sample counter - validation_stats[-1] += bsize - - if torch.distributed.is_initialized(): - torch.distributed.all_reduce(validation_stats) - - validation_error = (validation_stats[0] / validation_stats[-1]).item() - - # Record error per variable - validation_errors = [] - for v_idx, v_name in enumerate(self.output_variables): - validation_errors.append( - (validation_stats[1 + v_idx] / validation_stats[-1]).item() - ) - - # Track validation improvement to later check early stopping criterion - if validation_error < best_validation_error: - best_validation_error = validation_error - epochs_since_improved = 0 - else: - epochs_since_improved += 1 - - torch.cuda.nvtx.range_pop() - - # Logging and checkpoint saving - if self.dist.rank == 0: - if self.lr_scheduler is not None: - self.writer.add_scalar( - tag="learning_rate", - scalar_value=self.optimizer.param_groups[0]["lr"], - global_step=iteration, - ) - self.writer.add_scalar( - tag="val_loss", scalar_value=validation_error, global_step=iteration - ) - - # Per-variable loss - for v_idx, v_name in enumerate(self.output_variables): - self.writer.add_scalar( - tag=f"val_loss/{v_name}", - scalar_value=validation_errors[v_idx], - global_step=iteration, - ) - - # Write model checkpoint to file, using a separate thread - self.logger0.info("Writing checkpoint") - thread = threading.Thread( - target=write_checkpoint, - args=( - self.model.module - if torch.distributed.is_initialized() - else self.model, - self.optimizer, - self.lr_scheduler, - epoch, - iteration, - validation_error, - epochs_since_improved, - self.output_dir_tb, - ), - ) - thread.start() - - # Update learning rate - try: - if self.lr_scheduler is not None: - self.lr_scheduler.step() - except TypeError: # Plateau Learning rate requires val loss - if self.lr_scheduler is not None: - self.lr_scheduler.step(validation_error) - - # Check early stopping criterium - if ( - self.early_stopping_patience is not None - and epochs_since_improved >= self.early_stopping_patience - ): - self.logger0.info( - f"Hit early stopping criterium by not improving the validation error for {epochs_since_improved}" - " epochs. Finishing training." - ) - break - - # Wrap up - if self.dist.rank == 0: - try: - thread.join() - except UnboundLocalError: - pass - self.writer.flush() - self.writer.close() diff --git a/examples/weather/dlwp_healpix_coupled/utils.py b/examples/weather/dlwp_healpix_coupled/utils.py deleted file mode 100644 index 30ceb9cfbb..0000000000 --- a/examples/weather/dlwp_healpix_coupled/utils.py +++ /dev/null @@ -1,125 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import glob -import logging -import os -import re - -import numpy as np -import torch as th - -logger = logging.getLogger(__name__) - - -# TODO switch over to modulus checkpoint system -def write_checkpoint( - model, - optimizer, - scheduler, - epoch: int, - iteration: int, - val_error: float, - epochs_since_improved: int, - dst_path: str, - keep_n_checkpoints: int = 5, -): - """ - Writes a checkpoint including model, optimizer, and scheduler state dictionaries along with current epoch, - iteration, and validation error to file. - - :param model: The network model - :param optimizer: The pytorch optimizer - :param scheduler: The pytorch learning rate scheduler - :param epoch: Current training epoch - :param iteration: Current training iteration - :param val_error: The validation error of the current training - :param epochs_since_improved: The number of epochs since the validation error improved - :param dst_path: Path where the checkpoint is written to - :param keep_n_checkpoints: Number of best checkpoints that will be saved (worse checkpoints are overwritten) - """ - root_path = os.path.join( - dst_path, - "checkpoints", - ) - # root_path = os.path.dirname(ckpt_dst_path) - ckpt_dst_path = os.path.join( - root_path, - f"training-state-epoch-{str(epoch).zfill(4)}-val_loss=" - + "{:.4E}".format(val_error) - + ".mdlus", - ) - os.makedirs(root_path, exist_ok=True) - - model.save(ckpt_dst_path) - model.save(os.path.join(root_path, "training-state-last.mdlus")) - - opt_dst_path = os.path.join( - root_path, - f"optimizer-state-epoch-{str(epoch).zfill(4)}-val_loss=" - + "{:.4E}".format(val_error) - + ".ckpt", - ) - th.save( - obj={ - "optimizer_state_dict": optimizer.state_dict(), - "scheduler_state_dict": scheduler.state_dict(), - "epoch": epoch + 1, - "iteration": iteration, - "val_error": val_error, - "epochs_since_improved": epochs_since_improved, - }, - f=opt_dst_path, - ) - th.save( - obj={ - "optimizer_state_dict": optimizer.state_dict(), - "scheduler_state_dict": scheduler.state_dict(), - "epoch": epoch + 1, - "iteration": iteration, - "val_error": val_error, - "epochs_since_improved": epochs_since_improved, - }, - f=os.path.join(root_path, "optimizer-state-last.ckpt"), - ) - - # Only keep top n checkpoints - ckpt_paths = np.array(glob.glob(root_path + "/training-state-epoch-*.mdlus")) - if len(ckpt_paths) > keep_n_checkpoints + 1: - worst_path = "" - worst_error = -np.infty - for ckpt_path in ckpt_paths: - if "NAN" in ckpt_path: - os.remove(ckpt_path) - try: - os.remove(ckpt_path.replace("training", "optimizer")) - except FileNotFoundError: - pass - continue - # Read the scientific number from the checkpoint name and perform update if appropriate - curr_error = float( - re.findall("-?\d*\.?\d+E[+-]?\d+", os.path.basename(ckpt_path))[0] - ) - if curr_error > worst_error: - worst_path = ckpt_path - worst_error = curr_error - os.remove(worst_path) - try: - os.remove( - worst_path.replace("training", "optimizer").replace("mdlus", "ckpt") - ) - except FileNotFoundError: - pass