This is the official implementation of the NeurIPS 2023 SPOTLIGHT paper "Multi Time Scale World Models" and related models, including Ac-RKN and HiP-RSSM.
Figure: PGM of a 2-level MTS3 (Multi Time Scale State Space Model)
OpenReview | arXiv | Poster | Long Talk
The repository is built on Python 3.10 and PyTorch 1.13.1, and we are working on a migration to PyTorch 2.1.0. All necessary packages are listed in requirements.txt.
Example installation:
conda create --name mts3 python=3.10
conda activate mts3
pip install -r requirements.txt
- Multi Time Scale World Models
- Requirements
- Table Of Contents
- Datasets
- In a Nutshell
- In Details
- MTS3 Architecture
- Building Blocks (Gaussian Transformations)
- Related Models and Baselines
- Creating New Architectures
- Metrics Used
- Future Work
- Contributing
- Citation
The datasets are automatically downloaded from their URLs to the dataFolder/mts3_datasets_processed folder. See the readme.md in that folder for more details.
In a nutshell, here is how to run experiments on the datasets used in the MTS3 paper. After installing the necessary packages, go to the MTS3 folder.
To perform training and testing with the MTS3 model on the mobile robot dataset:
python experiments/mobileRobot/mts3_exp.py model=default_mts3
To run a baseline (let's say HiP-RSSM):
python experiments/mobileRobot/hiprssm_exp.py model=default_hiprssm
Similar commands can be used for other datasets like frankaKitchen, maze2d, halfCheetah etc. We use a large batch_size for A100 GPUs. For smaller GPUs, please reduce the batch_size in the config file.
It is recommended to read the Hydra documentation to fully understand the configuration framework. For help launching specific experiments, please file an issue. Read the experiments/readme.md for more details on how to run experiments with different hyperparameters.
MTS3
├── agent
│ ├── Infer
│ │ └── repre_infer_mts3.py - this file contains functions to perform inference in the MTS3 model
│ │ given some input (eg: multi step predictions)
│ │
│ ├── Learn
│ │ └── repre_learn_mts3.py - this file contains the training/learning loops
│ │
│ │
│ └── worldModels
│ ├── Decoders - this folder contains the decoders of different types
│ │ └── propDecoder.py - this file contains the decoder for the proprioceptive sensor
│ │
│ │
│ ├── gaussianTransformations - this folder contains the generic gaussian layers (see layers section)
│ │ ├── gaussian_conditioning.py
│ │ └── gaussian_marginalization.py
│ │
│ │
│ └── SensorEncoders - this folder contains the encoders for the different sensor modalities
│ │ └── propEncoder.py - this file contains the encoder for the proprioceptive sensor
│ │
│ │
│ ├── MTS3.py - this file contains the MTS3 model nn.Module
│ ├── hipRSSM.py - this file contains the hipRSSM model nn.Module
│ └── acRKN.py - this file contains the acRKN model nn.Module
│
│
├── dataFolder
│ └──mts3_datasets_processed - this folder contains the datasets used in MTS3 paper (after preprocessing)
│
│
├── experiments
│ │
│ ├── mobileRobot
│ │ ├── conf - this folder contains the config files for different models
│ │ │ └── model
│ │ │ ├── default_mts3.yaml
│ │ │ ├── default_acrkn.yaml
│ │ │ ├── default_hiprssm.yaml
│ │ │ │
│ │ │ └── learn
│ │ │ ├── default.yaml
│ │ │ └── default_rnn.yaml
│ │ │
│ │ ├── mts3_exp.py
│ │ ├── acrkn_exp.py
│ │ └── hiprssm_exp.py
│ │
│ ├── logs
│ │ └── output
│ │
│ ├── saved_models
│ │
│ ├── exp_prediction_mts3.py
│ ├── exp_prediction_acrkn.py
│ └── exp_prediction_hiprssm.py
│
│
└── utils
The task predict (slow time scale) and task-conditional state predict (fast time scale) steps are instances of the Gaussian Marginalization operation. The task update (slow time scale) and observation update (fast time scale) steps are instances of the Gaussian Conditioning operation.
Thus the MTS3 model can be viewed as a hierarchical composition of Gaussian Conditioning and Gaussian Marginalization operations. The building blocks of these operations are described in the next section.
The following building blocks are used in the MTS3 model to perform inference at each timescale. They can be broadly categorized into two types of layers/Gaussian transformations: Gaussian Conditioning and Gaussian Marginalization. These building blocks can be used to construct MTS3 with an arbitrary number of timescales.
The observation/task update and abstract action inference at every timescale are instances of the Gaussian Conditioning layer. It performs posterior inference over the latent state given a set of observations and the prior distribution (see the PGM below).
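As a rough illustration of what a conditioning step computes, the sketch below fuses a Gaussian prior with a set of Gaussian observations. It is not the repository's gaussian_conditioning.py: it assumes fully diagonal covariances and observations that measure the latent state directly, and the function name gaussian_conditioning_diag is hypothetical.

```python
from typing import List, Sequence, Tuple


def gaussian_conditioning_diag(
    prior_mean: Sequence[float],
    prior_var: Sequence[float],
    obs_means: Sequence[Sequence[float]],
    obs_vars: Sequence[Sequence[float]],
) -> Tuple[List[float], List[float]]:
    """Posterior of a diagonal Gaussian prior given a set of noisy observations.

    Assumption (not from the repository): every observation measures the
    latent state directly, with independent per-dimension Gaussian noise.
    """
    post_mean: List[float] = []
    post_var: List[float] = []
    for d, (mu, var) in enumerate(zip(prior_mean, prior_var)):
        # Work in precision (inverse variance) form: precisions simply add up.
        precision = 1.0 / var
        weighted_sum = mu / var
        for w, w_var in zip(obs_means, obs_vars):
            precision += 1.0 / w_var[d]
            weighted_sum += w[d] / w_var[d]
        post_var.append(1.0 / precision)
        post_mean.append(weighted_sum / precision)
    return post_mean, post_var


# Two observations update a 2-dimensional prior.
mean, var = gaussian_conditioning_diag(
    prior_mean=[0.0, 1.0],
    prior_var=[1.0, 4.0],
    obs_means=[[2.0, 1.0], [2.0, 3.0]],
    obs_vars=[[1.0, 4.0], [1.0, 4.0]],
)
```

Because precisions add, the posterior variance shrinks with every additional observation, and the update does not depend on the order of the observations in the set.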
The predict step at every timescale is an instance of the Gaussian Marginalization layer. It calculates the marginal distribution of the latent state at the next timestep, given a set of causal factors (see the PGM below).
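A minimal sketch of a predict step is given below, under simplifying assumptions: a linear transition with per-dimension (diagonal) coefficients, a single action as the only causal factor, and diagonal covariances. The function name gaussian_marginalization_diag and all parameter names are hypothetical; the repository's gaussian_marginalization.py uses its own parameterization.

```python
from typing import List, Sequence, Tuple


def gaussian_marginalization_diag(
    mean: Sequence[float],
    var: Sequence[float],
    transition: Sequence[float],
    control: Sequence[float],
    action: Sequence[float],
    process_var: Sequence[float],
) -> Tuple[List[float], List[float]]:
    """Propagate a diagonal Gaussian through a per-dimension linear model.

    Assumed model (illustrative only): z_next = a * z + b * u + noise,
    with noise ~ N(0, process_var), applied independently per dimension.
    """
    # The mean follows the linear model directly.
    next_mean = [a * m + b * u
                 for a, m, b, u in zip(transition, mean, control, action)]
    # The variance is scaled by a^2 and inflated by the process noise.
    next_var = [a * a * v + q
                for a, v, q in zip(transition, var, process_var)]
    return next_mean, next_var


next_mean, next_var = gaussian_marginalization_diag(
    mean=[1.0, -2.0],
    var=[0.5, 1.0],
    transition=[0.9, 1.0],
    control=[0.1, 0.5],
    action=[1.0, 2.0],
    process_var=[0.01, 0.1],
)
```

Unlike conditioning, marginalization never uses an observation, so uncertainty can only be carried forward and increased by the process noise.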
The mean and covariance of all the gaussian random variables in latent states have the following structure. The derivations are based on this factorization assumption, which allows for scalable inference without compromising on the expressiveness of the model.
The covariance matrix is represented/stored as a list
We provide implementations of related models such as Ac-RKN and HiP-RSSM. The models, and a readme on how they are related, can be found in the agent/worldModels folder. We also implement additional baselines such as GRUs and LSTMs in the same folder.
As you may notice, several of the models (rkn, acrkn, hiprssm, mts3) use the same set of latent Gaussian transformations. One can get creative with this and build new model architectures, for example by adding more hierarchies.
We use a sliding window RMSE as the metric when calculating the RMSE/NLL for multistep predictions, as reported in Figure 3 and Table 1 of the MTS3 paper. The multistep RMSE for a timestep t is taken as the RMSE over the last "window_len" time steps from t. The metrics can be found in utils/metrics.py, with function names like sliding_window_rmse and sliding_window_nll. The negative log-likelihood (NLL) reported in the table uses the same sliding window approach, but only the results of the last timestep are reported.
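The windowing described above can be sketched as follows for a single one-dimensional trajectory. This is an illustration only, not the code in utils/metrics.py: the function name sliding_window_rmse_sketch is hypothetical, and using a shorter window for the first few timesteps is an assumption.

```python
import math
from typing import List, Sequence


def sliding_window_rmse_sketch(
    targets: Sequence[float],
    predictions: Sequence[float],
    window_len: int,
) -> List[float]:
    """RMSE at each timestep t over the last `window_len` steps ending at t."""
    if len(targets) != len(predictions):
        raise ValueError("targets and predictions must have the same length")
    if window_len < 1:
        raise ValueError("window_len must be at least 1")

    squared_errors = [(p - y) ** 2 for p, y in zip(predictions, targets)]
    rmse_per_step: List[float] = []
    for t in range(len(squared_errors)):
        # Assumption: early timesteps use however many steps are available.
        window = squared_errors[max(0, t - window_len + 1): t + 1]
        rmse_per_step.append(math.sqrt(sum(window) / len(window)))
    return rmse_per_step


rmse_curve = sliding_window_rmse_sketch(
    targets=[0.0, 0.0, 0.0, 0.0],
    predictions=[1.0, 1.0, 3.0, 3.0],
    window_len=2,
)
```

Reporting only the final entry of such a curve corresponds to the "last timestep" convention mentioned for the table.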
Note: RMSE results are reported after converting the normalized predictions to the original scale. The normalization constants are stored under the normalizer key in the data dictionary. See the readme in the data folder for more details.
NLL results are reported in the normalized scale.
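To make the difference between the two scales concrete, here is a small sketch of converting predictions back to the original scale before computing the RMSE. It assumes a standard mean/standard-deviation normalization with scalar constants; the actual contents of the normalizer entry may differ, and the helper names denormalize and rmse are hypothetical.

```python
import math
from typing import List, Sequence


def denormalize(values: Sequence[float], mean: float, std: float) -> List[float]:
    """Undo an assumed z-score normalization: x = x_norm * std + mean."""
    return [v * std + mean for v in values]


def rmse(targets: Sequence[float], predictions: Sequence[float]) -> float:
    """Root mean squared error between two equally long sequences."""
    squared = [(p - y) ** 2 for p, y in zip(predictions, targets)]
    return math.sqrt(sum(squared) / len(squared))


# Hypothetical normalization constants, for illustration only.
data_mean, data_std = 10.0, 2.0
norm_targets = [0.5, 0.5]
norm_predictions = [0.0, 1.0]

rmse_normalized = rmse(norm_targets, norm_predictions)
rmse_original = rmse(
    denormalize(norm_targets, data_mean, data_std),
    denormalize(norm_predictions, data_mean, data_std),
)
```

Under this assumption the RMSE in the original scale is the normalized RMSE multiplied by the standard deviation, which is why the two sets of numbers are not directly comparable.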
We are working on the transition to PyTorch 2.0 and on adding Transformer baselines, an unactuated MTS3 (without actions, for time series/video prediction), etc.
Any kind of enhancement or contribution is welcome.
If you use this codebase, or otherwise find our work valuable, please cite MTS3 and other relevant papers.
@inproceedings{shaj2023multi,
title={Multi Time Scale World Models},
author={Shaj, Vaisakh and ZADEH, Saleh GHOLAM and Demir, Ozan and Douat, Luiz Ricardo and Neumann, Gerhard},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023}
}