Implementation of MetNet 3, a SOTA neural weather model out of Google DeepMind, in PyTorch
The model architecture is pretty unremarkable. It is basically a U-Net with a specific, well-performing vision transformer. The most interesting thing about the paper may end up being the loss scaling in section 4.3.2.
- StabilityAI, A16Z Open Source AI Grant Program, and 🤗 Huggingface for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
$ pip install metnet3-pytorch
import torch
from metnet3_pytorch import MetNet3
metnet3 = MetNet3(
    dim = 512,
    num_lead_times = 722,
    lead_time_embed_dim = 32,
    input_spatial_size = 624,
    attn_dim_head = 8,
    hrrr_channels = 617,
    input_2496_channels = 2 + 14 + 1 + 2 + 20,
    input_4996_channels = 16 + 1,
    precipitation_target_bins = dict(
        mrms_rate = 512,
        mrms_accumulation = 512,
    ),
    surface_target_bins = dict(
        omo_temperature = 256,
        omo_dew_point = 256,
        omo_wind_speed = 256,
        omo_wind_component_x = 256,
        omo_wind_component_y = 256,
        omo_wind_direction = 180
    ),
    hrrr_loss_weight = 10,
    hrrr_norm_strategy = 'sync_batchnorm',  # uses a sync batchnorm to normalize the hrrr input and target, so the per-channel mean and variance of the hrrr dataset do not have to be precalculated
    hrrr_norm_statistics = None             # you can also set `hrrr_norm_strategy = "precalculated"` and pass in the mean and variance as shape `(2, 617)` through this keyword argument
)
# inputs
lead_times = torch.randint(0, 722, (2,))
hrrr_input_2496 = torch.randn((2, 617, 624, 624))
hrrr_stale_state = torch.randn((2, 1, 624, 624))
input_2496 = torch.randn((2, 39, 624, 624))
input_4996 = torch.randn((2, 17, 624, 624))
# targets
precipitation_targets = dict(
    mrms_rate = torch.randint(0, 512, (2, 512, 512)),
    mrms_accumulation = torch.randint(0, 512, (2, 512, 512)),
)
surface_targets = dict(
    omo_temperature = torch.randint(0, 256, (2, 128, 128)),
    omo_dew_point = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_speed = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_component_x = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_component_y = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_direction = torch.randint(0, 180, (2, 128, 128))
)
hrrr_target = torch.randn(2, 617, 128, 128)
total_loss, loss_breakdown = metnet3(
    lead_times = lead_times,
    hrrr_input_2496 = hrrr_input_2496,
    hrrr_stale_state = hrrr_stale_state,
    input_2496 = input_2496,
    input_4996 = input_4996,
    precipitation_targets = precipitation_targets,
    surface_targets = surface_targets,
    hrrr_target = hrrr_target
)
total_loss.backward()
# after much training from above, you can predict as follows
metnet3.eval()
surface_preds, hrrr_pred, precipitation_preds = metnet3(
    lead_times = lead_times,
    hrrr_input_2496 = hrrr_input_2496,
    hrrr_stale_state = hrrr_stale_state,
    input_2496 = input_2496,
    input_4996 = input_4996,
)
# Dict[str, Tensor], Tensor, Dict[str, Tensor]
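The precipitation and surface targets in the usage example above are integer class indices, one per grid cell, with the number of classes set by `precipitation_target_bins` and `surface_target_bins`. If your raw data is continuous, it needs to be discretized first. Below is a minimal, dependency-free sketch of one way to do that, assuming uniform bins over a value range you choose yourself. The helper `to_bin_index` and the example ranges are hypothetical and are not part of `metnet3_pytorch`; the bin edges actually used in the paper may differ.

```python
def to_bin_index(value: float, low: float, high: float, num_bins: int) -> int:
    """Map a continuous value to an integer class index in [0, num_bins).

    Assumes uniform bins over [low, high); values outside the range are
    clamped to the first or last bin. Hypothetical helper for illustration.
    """
    if high <= low:
        raise ValueError("high must be greater than low")
    if num_bins < 1:
        raise ValueError("num_bins must be at least 1")

    width = (high - low) / num_bins
    index = int((value - low) // width)

    # clamp out-of-range values to the valid class range
    return max(0, min(num_bins - 1, index))


# assumed example range: temperature in [-64, 64) split into 256 bins of width 0.5
temperature_class = to_bin_index(0.0, -64.0, 64.0, 256)

# assumed example range: wind direction in [0, 360) degrees split into 180 bins of width 2
wind_direction_class = to_bin_index(359.9, 0.0, 360.0, 180)
```

Applied to every cell of a target grid, the resulting indices can be packed into a long tensor of shape `(batch, height, width)`, matching the `torch.randint` placeholders used above.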
- figure out all the cross entropy and MSE losses
- auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training (using sync batchnorm as hack)
- allow researcher to pass in their own normalization variables for HRRR
- build all the inputs to spec, also make sure hrrr input is normalized, offer option to unnormalize hrrr predictions
- make sure model can be easily saved and loaded, with different ways of handling hrrr norm
- figure out the topological embedding, consult a neural weather researcher
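The HRRR normalization items in the list above (running statistics during training, researcher-supplied statistics, and unnormalizing predictions) can be illustrated with a small dependency-free sketch. It is not the repository's implementation, which relies on sync batchnorm; the class `ChannelNorm` and its methods are hypothetical names, and Welford's algorithm is swapped in here as a simple way to track a per-channel running mean and variance.

```python
import math
from typing import List, Optional, Sequence, Tuple


class ChannelNorm:
    """Per-channel normalizer (hypothetical helper, not part of metnet3_pytorch)."""

    def __init__(
        self,
        num_channels: int,
        statistics: Optional[Tuple[Sequence[float], Sequence[float]]] = None,
        eps: float = 1e-5,
    ):
        self.eps = eps
        self.count = 0
        # "precalculated" mode: statistics are fixed and never updated
        self.frozen = statistics is not None

        if statistics is not None:
            mean, var = statistics  # (mean, variance), one value per channel
            assert len(mean) == len(var) == num_channels
            self.mean = list(mean)
            self.var = list(var)
        else:
            self.mean = [0.0] * num_channels
            self.var = [1.0] * num_channels
            self.m2 = [0.0] * num_channels  # running sum of squared deviations

    def update(self, sample: Sequence[float]) -> None:
        """Welford update with one observed value per channel."""
        if self.frozen:
            return
        self.count += 1
        for c, x in enumerate(sample):
            delta = x - self.mean[c]
            self.mean[c] += delta / self.count
            self.m2[c] += delta * (x - self.mean[c])
            if self.count > 1:
                self.var[c] = self.m2[c] / self.count  # population variance

    def normalize(self, sample: Sequence[float]) -> List[float]:
        return [
            (x - m) / math.sqrt(v + self.eps)
            for x, m, v in zip(sample, self.mean, self.var)
        ]

    def unnormalize(self, sample: Sequence[float]) -> List[float]:
        """Inverse of normalize, e.g. to map predictions back to physical units."""
        return [
            x * math.sqrt(v + self.eps) + m
            for x, m, v in zip(sample, self.mean, self.var)
        ]


# running statistics tracked from (toy) training targets with 2 channels
running = ChannelNorm(num_channels=2)
for target in ([1.0, 10.0], [3.0, 30.0]):
    running.update(target)

# researcher-supplied statistics, analogous in spirit to passing precalculated mean and variance
fixed = ChannelNorm(num_channels=2, statistics=([0.0, 1.0], [4.0, 9.0]), eps=0.0)
restored = running.unnormalize(running.normalize([5.0, 7.0]))
```

Saving and loading such a normalizer only requires persisting `mean`, `var`, and (for the running mode) `count` and `m2`, which is one reason the todo list treats checkpointing and the normalization strategy together.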
@article{Andrychowicz2023DeepLF,
title = {Deep Learning for Day Forecasts from Sparse Observations},
author = {Marcin Andrychowicz and Lasse Espeholt and Di Li and Samier Merchant and Alexander Merose and Fred Zyda and Shreya Agrawal and Nal Kalchbrenner},
journal = {ArXiv},
year = {2023},
volume = {abs/2306.06079},
url = {https://api.semanticscholar.org/CorpusID:259129311}
}
@inproceedings{ElNouby2021XCiTCI,
title = {XCiT: Cross-Covariance Image Transformers},
author = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
booktitle = {Neural Information Processing Systems},
year = {2021},
url = {https://api.semanticscholar.org/CorpusID:235458262}
}