Code for CaFA: Global Weather Forecasting with Factorized Attention on Sphere (arXiv: 2405.07395)
We provide a YAML file for the conda environment used in this project:

```bash
conda env create -f environment.yml
```
The ERA5 reanalysis data is provided courtesy of the Copernicus Climate Data Store. WeatherBench2 provides processed versions of it at different resolutions (in .zarr format). The data and normalization statistics in .npy format used in this project, derived from WeatherBench2's dataset, are provided below.
Resolution | Train range | Validation range | Link |
---|---|---|---|
240 x 121 | 1979-2018 | 2019 | link |
64 x 32 | 1979-2015 | 2016 | link |
(Note that the final model on 240 x 121 is trained with year 2019 included in the training data.)
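Once downloaded, the .npy data can be standardized with the per-channel statistics before being fed to the model. A minimal NumPy sketch; the array names, shapes, and layout here are assumptions for illustration, not the released archive's actual format:

```python
import numpy as np

# Stand-in array with an assumed (time, lat, lon, channel) layout;
# in practice you would np.load(...) the released .npy files.
data = np.random.randn(8, 32, 64, 5).astype(np.float32)

# Per-channel mean/std, as the normalization statistics would provide.
mean = data.mean(axis=(0, 1, 2))
std = data.std(axis=(0, 1, 2))

# Standardize each channel.
normalized = (data - mean) / std
print(normalized.shape)  # (8, 32, 64, 5)
```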
The configurations for training are provided under the configs directory.
Training a 100M-parameter CaFA on 64 x 32 resolution takes around 15 hours for stage 1 and 50 hours for stage 2 on a single RTX 3090.
Stage 1 training example (on 64 x 32 resolution):

```bash
bash run_stage1.sh
```

Stage 2 fine-tuning example (on 64 x 32 resolution):

```bash
bash run_stage2.sh
```
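The shell scripts point the training entry point at a YAML file under configs. A hypothetical snippet parsed with PyYAML, just to illustrate the pattern; these keys are made up, and the real hyperparameters live in the repository's config files:

```python
import yaml

# Hypothetical config snippet -- the actual keys are defined by the
# files under configs/, not by this sketch.
cfg_text = """
model:
  dim: 512
  depth: 12
training:
  lr: 1.0e-4
  batch_size: 16
"""
cfg = yaml.safe_load(cfg_text)
print(cfg["training"]["lr"])  # 0.0001
```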
The pre-trained checkpoints can be downloaded through the links below.
Resolution | Train range | # of Parameters | Link |
---|---|---|---|
240 x 121 | 1979-2019 | ~200M | link |
64 x 32 | 1979-2015 | ~100M | link |
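A downloaded checkpoint would typically be restored with torch.load. A hedged sketch using a stand-in module and an assumed "model_state_dict" key; inspect the actual checkpoint's keys before loading:

```python
import torch

# Stand-in for the CaFA model; the key name below is an assumption --
# check ckpt.keys() of the downloaded file first.
model = torch.nn.Linear(4, 4)
torch.save({"model_state_dict": model.state_dict()}, "ckpt.pt")

ckpt = torch.load("ckpt.pt", map_location="cpu")  # load weights on CPU
model.load_state_dict(ckpt["model_state_dict"])
model.eval()  # switch to inference mode
```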
To run model inference on processed .npy files, please refer to the validate_loop() function in dynamics_training_loop.py.
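Validation rolls the model out autoregressively: each prediction is fed back as the next input. A minimal sketch of that loop with a stand-in step function; the real implementation is validate_loop() in dynamics_training_loop.py, and the 0.99 decay here is purely illustrative:

```python
import numpy as np

def step(state):
    # Stand-in for one forward pass of the model: predicts the
    # next state from the current one.
    return state * 0.99

def rollout(initial_state, n_steps):
    """Autoregressively feed each prediction back in as the next input."""
    states = []
    state = initial_state
    for _ in range(n_steps):
        state = step(state)
        states.append(state)
    return np.stack(states)  # (n_steps, ...) trajectory

traj = rollout(np.ones((32, 64)), n_steps=4)
print(traj.shape)  # (4, 32, 64)
```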
We provide a demo notebook showcasing how to run the model on WeatherBench2's data; see inference_demo.ipynb.
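When evaluating forecasts on WeatherBench2-style data, skill is conventionally reported as latitude-weighted RMSE, which down-weights the densely gridded polar rows. A NumPy sketch of that metric, independent of the repository's own evaluation code:

```python
import numpy as np

def lat_weighted_rmse(pred, target, lats_deg):
    """Latitude-weighted RMSE (WeatherBench convention):
    weight each grid row by cos(latitude), normalized to mean 1."""
    w = np.cos(np.deg2rad(lats_deg))
    w = w / w.mean()                  # (lat,), mean-1 weights
    sq_err = (pred - target) ** 2     # (lat, lon) squared error
    return float(np.sqrt((w[:, None] * sq_err).mean()))

lats = np.linspace(-90, 90, 32)
pred = np.zeros((32, 64))
target = np.ones((32, 64))
rmse = lat_weighted_rmse(pred, target, lats)  # unit error everywhere -> ~1.0
```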
The ERA5 data used in this project is from the European Centre for Medium-Range Weather Forecasts. WeatherBench2 provides processed and aggregated versions, which are publicly available at link.
The spherical harmonics implementation is taken from se3_transformer_pytorch/spherical_harmonics.py.
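For intuition, the low-order real spherical harmonics such a module evaluates can be written in closed form. A NumPy sketch for l = 1, as an illustration rather than the repository's code path:

```python
import numpy as np

def real_sph_harm_l1(theta, phi):
    """Real spherical harmonics for l = 1 (angular basis functions
    on the sphere). theta: polar angle, phi: azimuth."""
    c = np.sqrt(3.0 / (4.0 * np.pi))
    return np.array([
        c * np.sin(theta) * np.sin(phi),  # Y_1^{-1}
        c * np.cos(theta),                # Y_1^{0}
        c * np.sin(theta) * np.cos(phi),  # Y_1^{1}
    ])

# At the north pole (theta = 0) only Y_1^0 is nonzero.
vals = real_sph_harm_l1(0.0, 0.0)
```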
If you find this project useful, please consider citing our work:

```bibtex
@misc{li2024cafa,
      title={CaFA: Global Weather Forecasting with Factorized Attention on Sphere},
      author={Zijie Li and Anthony Zhou and Saurabh Patil and Amir Barati Farimani},
      year={2024},
      eprint={2405.07395},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```